From e0a62302140a73185b380b5310f7fc7cbeb658c7 Mon Sep 17 00:00:00 2001
From: George Necula
Date: Thu, 31 Aug 2023 22:07:42 -0700
Subject: [PATCH] [host_callback] Delete unused code paths.

This is part of deprecating host_callback and moving to io_callback.

PiperOrigin-RevId: 561851494
---
 jax/experimental/host_callback.py | 173 +++++++++++-------------------
 tests/BUILD                       |   8 +-
 2 files changed, 63 insertions(+), 118 deletions(-)

diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py
index 462159f52..9ae1a7843 100644
--- a/jax/experimental/host_callback.py
+++ b/jax/experimental/host_callback.py
@@ -1002,6 +1002,8 @@ def _outside_call_translation_rule(ctx,
                                    **params):
   # We expect the current tokens at the end, inserted by _rewrite_jaxpr.
   assert has_token
+  use_outfeed = _use_outfeed(ctx.platform)
+  assert use_outfeed, 'Should be using MLIR path for `CustomCall` lowering'
   current_token = args_op[-2]
   current_itoken = args_op[-1]
   comp = ctx.builder
@@ -1015,124 +1017,71 @@ def _outside_call_translation_rule(ctx,
                                             flat_results_aval))
   need_callback_results_on_device = (not identity and
                                      len(non_empty_flat_results_aval) > 0)
-  use_outfeed = _use_outfeed(ctx.platform)
-  # TODO(sharadmv): Delete non-outfeed path when jaxlib minimum version is
-  # bumped past 0.3.8.
-  assert use_outfeed, 'Should be using MLIR path for `CustomCall` lowering'
   send_infeed = use_outfeed and need_callback_results_on_device
   generated_infeed = False  # Keep track if we emitted an infeed op
 
-  if use_outfeed:
-    _raise_if_using_outfeed_with_pjrt_c_api(xb.get_backend(ctx.platform))
-    callback_id = _register_callback(
-        functools.partial(
-            _outside_call_run_callback,
-            send_infeed=send_infeed,
-            identity=identity,
-            flat_results_aval=flat_results_aval,
-            **params))
-    next_token = _callback_handler_data.receiver.add_outfeed(
-        comp, current_token, callback_id, args_to_outfeed, device_index)
-    if identity:
-      results = list(args_to_outfeed)
-      next_itoken = current_itoken
-    else:
-      empty_results = [
-          xops.ConstantLiteral(comp, np.zeros(aval.shape, aval.dtype))
-          for aval in flat_results_aval
-          if _aval_is_empty(aval)
-      ]
-      if non_empty_flat_results_aval:
-        assert need_callback_results_on_device
-        after_outfeed_itoken = xops.AfterAll(comp, [current_itoken, next_token])
-        # We shard the infeed as AssignedDevice(device_index). This must match the
-        # outfeed (from outfeed_receiver.cc). Since `lax.infeed` does not support
-        # this kind of sharding, we use a custom translation for infeed.
-        array_sharding_proto = xla_client.OpSharding()
-        array_sharding_proto.type = xla_client.OpSharding.Type.MAXIMAL
-        array_sharding_proto.tile_assignment_dimensions = [1]
-        array_sharding_proto.tile_assignment_devices = [device_index]
-        token_sharding_proto = xla_client.OpSharding()
-        token_sharding_proto.type = xla_client.OpSharding.Type.REPLICATED
-        infeed_sharding_proto = xla.tuple_sharding_proto(
-            [array_sharding_proto] * len(non_empty_flat_results_aval) +
-            [token_sharding_proto])
-
-        shape = [
-            shape.with_major_to_minor_layout_if_absent()
-            for x in non_empty_flat_results_aval
-            for shape in xla.aval_to_xla_shapes(x)
-        ]
-
-        build_infeed = functools.partial(xops.InfeedWithToken,
-                                         after_outfeed_itoken,
-                                         xla_client.Shape.tuple_shape(shape))
-        outs_and_token = _with_sharding_proto(comp, infeed_sharding_proto,
-                                              build_infeed)
-        outs = xops.GetTupleElement(outs_and_token, 0)
-        next_itoken = xops.GetTupleElement(outs_and_token, 1)
-        non_empty_results = [
-            xops.GetTupleElement(outs, i)
-            for i in range(len(non_empty_flat_results_aval))
-        ]
-        generated_infeed = True
-        results = [
-            empty_results.pop(0)
-            if _aval_is_empty(result_aval) else non_empty_results.pop(0)
-            for result_aval in flat_results_aval
-        ]
-      else:
-        results = empty_results
-        next_itoken = current_itoken
-
-  else:  # !use_outfeed : CustomCall implementation
-    if device_index != 0:
-      raise ValueError("The device_index feature works only when using outfeed.")
-
-    # TODO(necula): this is a weak attempt to get the device. This works
-    # inside pmap, but does not work when we just execute on a single device,
-    # because in such executions we always get replica_id == 0.
-    replica_id = xla_client.ops.ReplicaId(comp)
-    callback_operands = (current_token, replica_id) + args_to_outfeed
-    if identity:
-      callback_flat_results_aval = (core.abstract_token,)
-    else:
-      callback_flat_results_aval = (core.abstract_token,) + flat_results_aval
-
-    def wrapped_callback(*args):
-      token, replica_id, *arrays = args
-      result_arrays = _outside_call_run_callback(
-          arrays,
-          xb.local_devices()[replica_id],
-          send_infeed=False,
-          # The same parameters as outside_call_p
+  _raise_if_using_outfeed_with_pjrt_c_api(xb.get_backend(ctx.platform))
+  callback_id = _register_callback(
+      functools.partial(
+          _outside_call_run_callback,
+          send_infeed=send_infeed,
           identity=identity,
           flat_results_aval=flat_results_aval,
-          **params)
-      if identity:
-        # For identity, we do not pass the any results back to the device
-        result_arrays = ()
-      return (token,) + result_arrays
-
-    result_shapes = [
-        xla.aval_to_xla_shapes(res_aval)[0]
-        for res_aval in callback_flat_results_aval
-    ]
-    backend = ctx.module_context.backend
-    token_and_results_op, keep_alive = backend.emit_python_callback(
-        wrapped_callback,
-        comp,
-        callback_operands,
-        result_shapes,
-        operand_layouts=None,
-        has_side_effects=True)
-    _callback_handler_data.keep_alives.append(keep_alive)
-    next_token, *results = (xops.GetTupleElement(token_and_results_op, i)
-                            for i in range(len(callback_flat_results_aval)))
-    # We must put the two tokens at the end
-    if identity:
-      results = list(args_to_outfeed)
+          **params))
+  next_token = _callback_handler_data.receiver.add_outfeed(
+      comp, current_token, callback_id, args_to_outfeed, device_index)
+  if identity:
+    results = list(args_to_outfeed)
     next_itoken = current_itoken
+  else:
+    empty_results = [
+        xops.ConstantLiteral(comp, np.zeros(aval.shape, aval.dtype))
+        for aval in flat_results_aval
+        if _aval_is_empty(aval)
+    ]
+    if non_empty_flat_results_aval:
+      assert need_callback_results_on_device
+      after_outfeed_itoken = xops.AfterAll(comp, [current_itoken, next_token])
+      # We shard the infeed as AssignedDevice(device_index). This must match the
+      # outfeed (from outfeed_receiver.cc). Since `lax.infeed` does not support
+      # this kind of sharding, we use a custom translation for infeed.
+      array_sharding_proto = xla_client.OpSharding()
+      array_sharding_proto.type = xla_client.OpSharding.Type.MAXIMAL
+      array_sharding_proto.tile_assignment_dimensions = [1]
+      array_sharding_proto.tile_assignment_devices = [device_index]
+
+      token_sharding_proto = xla_client.OpSharding()
+      token_sharding_proto.type = xla_client.OpSharding.Type.REPLICATED
+      infeed_sharding_proto = xla.tuple_sharding_proto(
+          [array_sharding_proto] * len(non_empty_flat_results_aval) +
+          [token_sharding_proto])
+
+      shape = [
+          shape.with_major_to_minor_layout_if_absent()
+          for x in non_empty_flat_results_aval
+          for shape in xla.aval_to_xla_shapes(x)
+      ]
+
+      build_infeed = functools.partial(xops.InfeedWithToken,
+                                       after_outfeed_itoken,
+                                       xla_client.Shape.tuple_shape(shape))
+      outs_and_token = _with_sharding_proto(comp, infeed_sharding_proto,
+                                            build_infeed)
+      outs = xops.GetTupleElement(outs_and_token, 0)
+      next_itoken = xops.GetTupleElement(outs_and_token, 1)
+      non_empty_results = [
+          xops.GetTupleElement(outs, i)
+          for i in range(len(non_empty_flat_results_aval))
+      ]
+      generated_infeed = True
+      results = [
+          empty_results.pop(0)
+          if _aval_is_empty(result_aval) else non_empty_results.pop(0)
+          for result_aval in flat_results_aval
+      ]
+    else:
+      results = empty_results
+      next_itoken = current_itoken
 
   assert generated_infeed == send_infeed, (
       f"generated_infeed ({generated_infeed}) != send_infeed ({send_infeed})")
diff --git a/tests/BUILD b/tests/BUILD
index 1af03063a..0f897e257 100644
--- a/tests/BUILD
+++ b/tests/BUILD
@@ -961,7 +961,7 @@ jax_test(
 )
 
 jax_test(
-    name = "host_callback_test",
+    name = "host_callback_outfeed_test",
     srcs = ["host_callback_test.py"],
     args = ["--jax_host_callback_outfeed=true"],
     shard_count = {
@@ -975,13 +975,9 @@ jax_test(
 )
 
 jax_test(
-    name = "host_callback_custom_call_test",
+    name = "host_callback_test",
     srcs = ["host_callback_test.py"],
     args = ["--jax_host_callback_outfeed=false"],
-    disable_backends = [
-        "gpu",
-        "tpu",  # On TPU we always use outfeed
-    ],
     main = "host_callback_test.py",
     shard_count = {
         "gpu": 5,
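
Note on the migration target: the commit message points to io_callback as the
replacement for host_callback. Below is a minimal, illustrative sketch (not
part of this patch), assuming the jax.experimental.io_callback API
(io_callback(callback, result_shape_dtypes, *args, ordered=...)); the names
host_log and f are hypothetical.

    # Hypothetical sketch of the replacement API named in the commit message.
    import jax
    import jax.numpy as jnp
    from jax.experimental import io_callback

    def host_log(x):
      # Runs on the host; must return values matching result_shape_dtypes.
      print("on host:", x)
      return x

    @jax.jit
    def f(x):
      # ordered=True preserves the relative order of callback executions,
      # which is close to host_callback's tap semantics.
      y = io_callback(host_log, jax.ShapeDtypeStruct(x.shape, x.dtype), x,
                      ordered=True)
      return y + 1.0

    print(f(jnp.arange(3.0)))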