[host_callback] Delete unused code paths.

This is part of deprecating host_callback and moving to io_callback.

PiperOrigin-RevId: 561851494
This commit is contained in:
George Necula 2023-08-31 22:07:42 -07:00 committed by jax authors
parent 70b58bbd30
commit e0a6230214
2 changed files with 63 additions and 118 deletions

View File

@ -1002,6 +1002,8 @@ def _outside_call_translation_rule(ctx,
**params):
# We expect the current tokens at the end, inserted by _rewrite_jaxpr.
assert has_token
use_outfeed = _use_outfeed(ctx.platform)
assert use_outfeed, 'Should be using MLIR path for `CustomCall` lowering'
current_token = args_op[-2]
current_itoken = args_op[-1]
comp = ctx.builder
@ -1015,13 +1017,9 @@ def _outside_call_translation_rule(ctx,
flat_results_aval))
need_callback_results_on_device = (not identity and
len(non_empty_flat_results_aval) > 0)
use_outfeed = _use_outfeed(ctx.platform)
# TODO(sharadmv): Delete non-outfeed path when jaxlib minimum version is
# bumped past 0.3.8.
assert use_outfeed, 'Should be using MLIR path for `CustomCall` lowering'
send_infeed = use_outfeed and need_callback_results_on_device
generated_infeed = False # Keep track if we emitted an infeed op
if use_outfeed:
_raise_if_using_outfeed_with_pjrt_c_api(xb.get_backend(ctx.platform))
callback_id = _register_callback(
functools.partial(
@ -1085,55 +1083,6 @@ def _outside_call_translation_rule(ctx,
results = empty_results
next_itoken = current_itoken
else: # !use_outfeed : CustomCall implementation
if device_index != 0:
raise ValueError("The device_index feature works only when using outfeed.")
# TODO(necula): this is a weak attempt to get the device. This works
# inside pmap, but does not work when we just execute on a single device,
# because in such executions we always get replica_id == 0.
replica_id = xla_client.ops.ReplicaId(comp)
callback_operands = (current_token, replica_id) + args_to_outfeed
if identity:
callback_flat_results_aval = (core.abstract_token,)
else:
callback_flat_results_aval = (core.abstract_token,) + flat_results_aval
def wrapped_callback(*args):
token, replica_id, *arrays = args
result_arrays = _outside_call_run_callback(
arrays,
xb.local_devices()[replica_id],
send_infeed=False,
# The same parameters as outside_call_p
identity=identity,
flat_results_aval=flat_results_aval,
**params)
if identity:
# For identity, we do not pass any results back to the device
result_arrays = ()
return (token,) + result_arrays
result_shapes = [
xla.aval_to_xla_shapes(res_aval)[0]
for res_aval in callback_flat_results_aval
]
backend = ctx.module_context.backend
token_and_results_op, keep_alive = backend.emit_python_callback(
wrapped_callback,
comp,
callback_operands,
result_shapes,
operand_layouts=None,
has_side_effects=True)
_callback_handler_data.keep_alives.append(keep_alive)
next_token, *results = (xops.GetTupleElement(token_and_results_op, i)
for i in range(len(callback_flat_results_aval)))
# We must put the two tokens at the end
if identity:
results = list(args_to_outfeed)
next_itoken = current_itoken
assert generated_infeed == send_infeed, (
f"generated_infeed ({generated_infeed}) != send_infeed ({send_infeed})")
assert identity or len(results) == len(flat_results_aval), (

View File

@ -961,7 +961,7 @@ jax_test(
)
jax_test(
name = "host_callback_test",
name = "host_callback_outfeed_test",
srcs = ["host_callback_test.py"],
args = ["--jax_host_callback_outfeed=true"],
shard_count = {
@ -975,13 +975,9 @@ jax_test(
)
jax_test(
name = "host_callback_custom_call_test",
name = "host_callback_test",
srcs = ["host_callback_test.py"],
args = ["--jax_host_callback_outfeed=false"],
disable_backends = [
"gpu",
"tpu", # On TPU we always use outfeed
],
main = "host_callback_test.py",
shard_count = {
"gpu": 5,