mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
[MHLO] Separate registrations for collective and initial_style primitives from the XLA translation rule registration.
This change is in preparation for removing XLA translation rules for many primitives. However, even after the MHLO switch we still need to tag collective and initial_style primitives.
PiperOrigin-RevId: 441474701
This commit is contained in:
parent
ad8e6ada4e
commit
cb4abe754a
@ -231,10 +231,10 @@ custom_vmap_p.def_impl(custom_vmap_impl)
|
||||
custom_vmap_p.def_abstract_eval(custom_vmap_abstract_eval)
|
||||
batching.primitive_batchers[custom_vmap_p] = custom_vmap_batching
|
||||
ad.primitive_jvps[custom_vmap_p] = custom_vmap_jvp
|
||||
xla.register_initial_style_primitive(custom_vmap_p)
|
||||
xla.register_translation(custom_vmap_p,
|
||||
xla.lower_fun(custom_vmap_impl, new_style=True,
|
||||
multiple_results=True),
|
||||
initial_style=True)
|
||||
multiple_results=True))
|
||||
mlir.register_lowering(custom_vmap_p, mlir.lower_fun(
|
||||
custom_vmap_impl, multiple_results=True))
|
||||
|
||||
|
@ -397,11 +397,11 @@ def _custom_jvp_call_jaxpr_vmap(
|
||||
return batched_outs, out_dims
|
||||
batching.axis_primitive_batchers[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_vmap
|
||||
|
||||
xla.register_initial_style_primitive(custom_jvp_call_jaxpr_p)
|
||||
xla.register_translation(
|
||||
custom_jvp_call_jaxpr_p,
|
||||
xla.lower_fun(_custom_jvp_call_jaxpr_impl, new_style=True,
|
||||
multiple_results=True),
|
||||
initial_style=True)
|
||||
multiple_results=True))
|
||||
|
||||
# If a (multi)linear function is defined with a custom jvp, then
|
||||
# custom_jvp_call_jaxpr can appear in jaxprs to be transposed. Since it's
|
||||
@ -768,11 +768,11 @@ def _custom_vjp_call_jaxpr_vmap(
|
||||
return batched_outs, out_dims
|
||||
batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap
|
||||
|
||||
xla.register_initial_style_primitive(custom_vjp_call_jaxpr_p)
|
||||
xla.register_translation(
|
||||
custom_vjp_call_jaxpr_p,
|
||||
xla.lower_fun(_custom_vjp_call_jaxpr_impl, new_style=True,
|
||||
multiple_results=True),
|
||||
initial_style=True)
|
||||
multiple_results=True))
|
||||
|
||||
batching.primitive_batchers[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp
|
||||
xla.register_translation(ad.custom_lin_p, ad._raise_custom_vjp_error_on_jvp)
|
||||
@ -1164,10 +1164,10 @@ linear_call_p.multiple_results = True
|
||||
linear_call_p.def_impl(_linear_call_impl)
|
||||
linear_call_p.def_abstract_eval(_linear_call_abstract_eval)
|
||||
ad.primitive_transposes[linear_call_p] = _linear_call_transpose_rule
|
||||
xla.register_initial_style_primitive(linear_call_p)
|
||||
xla.register_translation(linear_call_p,
|
||||
xla.lower_fun(_linear_call_impl, new_style=True,
|
||||
multiple_results=True),
|
||||
initial_style=True)
|
||||
multiple_results=True))
|
||||
mlir.register_lowering(linear_call_p, mlir.lower_fun(
|
||||
_linear_call_impl, multiple_results=True))
|
||||
|
||||
|
@ -225,8 +225,8 @@ ad.primitive_transposes[custom_transpose_p] = custom_transpose_transpose_rule
|
||||
mlir.register_lowering(
|
||||
custom_transpose_p,
|
||||
mlir.lower_fun(custom_transpose_lowering, multiple_results=True))
|
||||
xla.register_initial_style_primitive(custom_transpose_p)
|
||||
xla.register_translation(
|
||||
custom_transpose_p,
|
||||
xla.lower_fun(
|
||||
custom_transpose_lowering, new_style=True, multiple_results=True),
|
||||
initial_style=True)
|
||||
custom_transpose_lowering, new_style=True, multiple_results=True))
|
||||
|
@ -620,8 +620,8 @@ while_p.def_impl(partial(xla.apply_primitive, while_p))
|
||||
while_p.def_effectful_abstract_eval(_while_loop_abstract_eval)
|
||||
ad.primitive_jvps[while_p] = _while_loop_jvp
|
||||
pe.custom_partial_eval_rules[while_p] = _while_partial_eval
|
||||
xla.register_translation(while_p, _while_loop_translation_rule,
|
||||
initial_style=True)
|
||||
xla.register_initial_style_primitive(while_p)
|
||||
xla.register_translation(while_p, _while_loop_translation_rule)
|
||||
ad.primitive_transposes[while_p] = _while_transpose_error
|
||||
batching.axis_primitive_batchers[while_p] = _while_loop_batching_rule
|
||||
pe.partial_eval_jaxpr_custom_rules[while_p] = \
|
||||
@ -1342,7 +1342,8 @@ ad.primitive_jvps[cond_p] = _cond_jvp
|
||||
ad.reducing_transposes[cond_p] = _cond_transpose
|
||||
pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval
|
||||
batching.axis_primitive_batchers[cond_p] = _cond_batching_rule
|
||||
xla.register_translation(cond_p, _cond_translation_rule, initial_style=True)
|
||||
xla.register_initial_style_primitive(cond_p)
|
||||
xla.register_translation(cond_p, _cond_translation_rule)
|
||||
core.custom_typechecks[cond_p] = _cond_typecheck
|
||||
pe.partial_eval_jaxpr_custom_rules[cond_p] = \
|
||||
partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'cond')
|
||||
@ -2132,9 +2133,9 @@ scan_p.def_effectful_abstract_eval(_scan_abstract_eval)
|
||||
ad.primitive_jvps[scan_p] = _scan_jvp
|
||||
ad.reducing_transposes[scan_p] = _scan_transpose
|
||||
pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval
|
||||
xla.register_initial_style_primitive(scan_p)
|
||||
xla.register_translation(scan_p, xla.lower_fun(_scan_impl, new_style=True,
|
||||
multiple_results=True),
|
||||
initial_style=True)
|
||||
multiple_results=True))
|
||||
mlir.register_lowering(scan_p,
|
||||
mlir.lower_fun(_scan_impl, multiple_results=True))
|
||||
batching.axis_primitive_batchers[scan_p] = _scan_batching_rule
|
||||
@ -2692,10 +2693,10 @@ linear_solve_p.multiple_results = True
|
||||
linear_solve_p.def_impl(_custom_linear_solve_impl)
|
||||
linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval)
|
||||
ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp
|
||||
xla.register_initial_style_primitive(linear_solve_p)
|
||||
xla.register_translation(
|
||||
linear_solve_p, xla.lower_fun(_custom_linear_solve_impl, new_style=True,
|
||||
multiple_results=True),
|
||||
initial_style=True)
|
||||
multiple_results=True))
|
||||
mlir.register_lowering(
|
||||
linear_solve_p, mlir.lower_fun(_custom_linear_solve_impl,
|
||||
multiple_results=True))
|
||||
|
@ -783,9 +783,9 @@ psum_p = core.AxisPrimitive('psum')
|
||||
psum_p.multiple_results = True
|
||||
psum_p.def_impl(partial(_allreduce_impl, lax._reduce_sum))
|
||||
psum_p.def_abstract_eval(_allreduce_abstract_eval)
|
||||
xla.register_collective_primitive(psum_p)
|
||||
xla.register_translation(
|
||||
psum_p, partial(_allreduce_translation_rule, lax.add_p, lax._reduce_sum),
|
||||
is_collective=True)
|
||||
psum_p, partial(_allreduce_translation_rule, lax.add_p, lax._reduce_sum))
|
||||
mlir.register_lowering(
|
||||
psum_p, partial(_allreduce_lowering, lax.add_p, lax._reduce_sum))
|
||||
ad.deflinear2(psum_p, _psum_transpose_rule)
|
||||
@ -822,9 +822,9 @@ pmax_p = core.AxisPrimitive('pmax')
|
||||
pmax_p.multiple_results = True
|
||||
pmax_p.def_impl(partial(_allreduce_impl, lax._reduce_max))
|
||||
pmax_p.def_abstract_eval(_allreduce_abstract_eval)
|
||||
xla.register_collective_primitive(pmax_p)
|
||||
xla.register_translation(
|
||||
pmax_p, partial(_allreduce_translation_rule, lax.max_p, lax._reduce_max),
|
||||
is_collective=True)
|
||||
pmax_p, partial(_allreduce_translation_rule, lax.max_p, lax._reduce_max))
|
||||
mlir.register_lowering(
|
||||
pmax_p, partial(_allreduce_lowering, lax.max_p, lax._reduce_max))
|
||||
pxla.multi_host_supported_collectives.add(pmax_p)
|
||||
@ -838,9 +838,9 @@ pmin_p = core.AxisPrimitive('pmin')
|
||||
pmin_p.multiple_results = True
|
||||
pmin_p.def_impl(partial(_allreduce_impl, lax._reduce_min))
|
||||
pmin_p.def_abstract_eval(_allreduce_abstract_eval)
|
||||
xla.register_collective_primitive(pmin_p)
|
||||
xla.register_translation(
|
||||
pmin_p, partial(_allreduce_translation_rule, lax.min_p, lax._reduce_min),
|
||||
is_collective=True)
|
||||
pmin_p, partial(_allreduce_translation_rule, lax.min_p, lax._reduce_min))
|
||||
mlir.register_lowering(
|
||||
pmin_p, partial(_allreduce_lowering, lax.min_p, lax._reduce_min))
|
||||
pxla.multi_host_supported_collectives.add(pmin_p)
|
||||
@ -910,8 +910,8 @@ def _collective_batcher(prim, args, dims, **params):
|
||||
ppermute_p = core.AxisPrimitive('ppermute')
|
||||
ppermute_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x))
|
||||
ad.deflinear2(ppermute_p, _ppermute_transpose_rule)
|
||||
xla.register_translation(ppermute_p, _ppermute_translation_rule,
|
||||
is_collective=True)
|
||||
xla.register_collective_primitive(ppermute_p)
|
||||
xla.register_translation(ppermute_p, _ppermute_translation_rule)
|
||||
mlir.register_lowering(ppermute_p, _ppermute_lowering)
|
||||
pxla.multi_host_supported_collectives.add(ppermute_p)
|
||||
batching.primitive_batchers[ppermute_p] = partial(_collective_batcher, ppermute_p)
|
||||
@ -1102,8 +1102,8 @@ def _all_to_all_abstract_eval(x, axis_name, split_axis, concat_axis, axis_index_
|
||||
|
||||
all_to_all_p = core.AxisPrimitive('all_to_all')
|
||||
all_to_all_p.def_abstract_eval(_all_to_all_abstract_eval)
|
||||
xla.register_translation(all_to_all_p, _all_to_all_translation_rule,
|
||||
is_collective=True)
|
||||
xla.register_collective_primitive(all_to_all_p)
|
||||
xla.register_translation(all_to_all_p, _all_to_all_translation_rule)
|
||||
mlir.register_lowering(all_to_all_p, _all_to_all_lowering)
|
||||
ad.deflinear2(all_to_all_p, _all_to_all_transpose_rule)
|
||||
pxla.multi_host_supported_collectives.add(all_to_all_p)
|
||||
@ -1323,8 +1323,8 @@ def _all_gather_batched_collective(frame_size, frame_name, _, vals_in, dims_in,
|
||||
all_gather_p = core.AxisPrimitive('all_gather')
|
||||
all_gather_p.def_abstract_eval(_all_gather_abstract_eval)
|
||||
all_gather_p.def_impl(_all_gather_impl)
|
||||
xla.register_translation(all_gather_p, _all_gather_translation_rule,
|
||||
is_collective=True)
|
||||
xla.register_collective_primitive(all_gather_p)
|
||||
xla.register_translation(all_gather_p, _all_gather_translation_rule)
|
||||
mlir.register_lowering(all_gather_p, _all_gather_lowering)
|
||||
ad.deflinear2(all_gather_p, _all_gather_transpose_rule)
|
||||
pxla.multi_host_supported_collectives.add(all_gather_p)
|
||||
@ -1462,10 +1462,10 @@ def _reduce_scatter_abstract_eval(x, *, axis_name, scatter_dimension,
|
||||
|
||||
reduce_scatter_p = core.AxisPrimitive("reduce_scatter")
|
||||
reduce_scatter_p.def_abstract_eval(_reduce_scatter_abstract_eval)
|
||||
xla.register_collective_primitive(reduce_scatter_p)
|
||||
xla.register_translation(
|
||||
reduce_scatter_p,
|
||||
partial(_reduce_scatter_translation_rule, lax.add_p, psum),
|
||||
is_collective=True)
|
||||
partial(_reduce_scatter_translation_rule, lax.add_p, psum))
|
||||
mlir.register_lowering(
|
||||
reduce_scatter_p,
|
||||
partial(_reduce_scatter_lowering, lax.add_p, psum))
|
||||
@ -1590,8 +1590,8 @@ def _axis_index_abstract_eval(*, axis_name):
|
||||
return ShapedArray((), np.int32, named_shape={axis_name: frame.size})
|
||||
|
||||
axis_index_p = core.Primitive('axis_index')
|
||||
xla.register_translation(axis_index_p, _axis_index_translation_rule,
|
||||
is_collective=True)
|
||||
xla.register_collective_primitive(axis_index_p)
|
||||
xla.register_translation(axis_index_p, _axis_index_translation_rule)
|
||||
mlir.register_lowering(axis_index_p, _axis_index_lowering)
|
||||
axis_index_p.def_abstract_eval(_axis_index_abstract_eval)
|
||||
pxla.multi_host_supported_collectives.add(axis_index_p)
|
||||
@ -1683,10 +1683,10 @@ def _pdot_lowering(x, y, *, axis_name, pos_contract, pos_batch, precision):
|
||||
precision=precision, preferred_element_type=None)
|
||||
return psum(local_out, axis_name) if axis_name is not None else local_out
|
||||
|
||||
xla.register_collective_primitive(pdot_p)
|
||||
xla.register_translation(
|
||||
pdot_p,
|
||||
xla.lower_fun(_pdot_lowering, multiple_results=False, new_style=True),
|
||||
is_collective=True)
|
||||
xla.lower_fun(_pdot_lowering, multiple_results=False, new_style=True))
|
||||
mlir.register_lowering(
|
||||
pdot_p,
|
||||
mlir.lower_fun(_pdot_lowering, multiple_results=False))
|
||||
@ -1785,8 +1785,8 @@ def _pgather_collective_batcher(axis_size, frame_name, _, vals_in, dims_in, *, a
|
||||
pgather_p = core.AxisPrimitive('pgather')
|
||||
pgather_p.def_impl(_pgather_impl)
|
||||
pgather_p.def_abstract_eval(_pgather_abstract_eval)
|
||||
xla.register_translation(pgather_p, _pgather_parallel_translation,
|
||||
is_collective=True)
|
||||
xla.register_collective_primitive(pgather_p)
|
||||
xla.register_translation(pgather_p, _pgather_parallel_translation)
|
||||
mlir.register_lowering(pgather_p, _pgather_parallel_lowering)
|
||||
# TODO: Transpose? That requires adding pscatter...
|
||||
batching.primitive_batchers[pgather_p] = _pgather_batcher
|
||||
|
@ -870,17 +870,17 @@ _backend_specific_translations = defaultdict(dict)
|
||||
_collective_primitives: Set[core.Primitive] = set()
|
||||
_initial_style_primitives: Set[core.Primitive] = set()
|
||||
|
||||
def register_initial_style_primitive(prim: core.Primitive):
|
||||
_initial_style_primitives.add(prim)
|
||||
|
||||
def register_collective_primitive(prim: core.Primitive):
|
||||
_collective_primitives.add(prim)
|
||||
|
||||
def register_translation(prim: core.Primitive, rule: TranslationRule, *,
|
||||
platform: Optional[str] = None,
|
||||
is_collective: bool = False,
|
||||
initial_style: bool = False) -> None:
|
||||
platform: Optional[str] = None) -> None:
|
||||
ts = (_translations if platform is None
|
||||
else _backend_specific_translations[platform])
|
||||
ts[prim] = rule
|
||||
if is_collective:
|
||||
_collective_primitives.add(prim)
|
||||
if initial_style:
|
||||
_initial_style_primitives.add(prim)
|
||||
|
||||
# As a temporary backward compatibility measure, we use an adapter class to
|
||||
# convert from the old styles of translation rules to the newer ones.
|
||||
|
Loading…
x
Reference in New Issue
Block a user