[host_callback] Delete unused code paths.

This is part of deprecating host_callback and moving to io_callback.

PiperOrigin-RevId: 561851494
This commit is contained in:
George Necula 2023-08-31 22:07:42 -07:00 committed by jax authors
parent 70b58bbd30
commit e0a6230214
2 changed files with 63 additions and 118 deletions

View File

@ -1002,6 +1002,8 @@ def _outside_call_translation_rule(ctx,
**params):
# We expect the current tokens at the end, inserted by _rewrite_jaxpr.
assert has_token
use_outfeed = _use_outfeed(ctx.platform)
assert use_outfeed, 'Should be using MLIR path for `CustomCall` lowering'
current_token = args_op[-2]
current_itoken = args_op[-1]
comp = ctx.builder
@ -1015,124 +1017,71 @@ def _outside_call_translation_rule(ctx,
flat_results_aval))
need_callback_results_on_device = (not identity and
len(non_empty_flat_results_aval) > 0)
use_outfeed = _use_outfeed(ctx.platform)
# TODO(sharadmv): Delete non-outfeed path when jaxlib minimum version is
# bumped past 0.3.8.
assert use_outfeed, 'Should be using MLIR path for `CustomCall` lowering'
send_infeed = use_outfeed and need_callback_results_on_device
generated_infeed = False # Keep track if we emitted an infeed op
if use_outfeed:
_raise_if_using_outfeed_with_pjrt_c_api(xb.get_backend(ctx.platform))
callback_id = _register_callback(
functools.partial(
_outside_call_run_callback,
send_infeed=send_infeed,
identity=identity,
flat_results_aval=flat_results_aval,
**params))
next_token = _callback_handler_data.receiver.add_outfeed(
comp, current_token, callback_id, args_to_outfeed, device_index)
if identity:
results = list(args_to_outfeed)
next_itoken = current_itoken
else:
empty_results = [
xops.ConstantLiteral(comp, np.zeros(aval.shape, aval.dtype))
for aval in flat_results_aval
if _aval_is_empty(aval)
]
if non_empty_flat_results_aval:
assert need_callback_results_on_device
after_outfeed_itoken = xops.AfterAll(comp, [current_itoken, next_token])
# We shard the infeed as AssignedDevice(device_index). This must match the
# outfeed (from outfeed_receiver.cc). Since `lax.infeed` does not support
# this kind of sharding, we use a custom translation for infeed.
array_sharding_proto = xla_client.OpSharding()
array_sharding_proto.type = xla_client.OpSharding.Type.MAXIMAL
array_sharding_proto.tile_assignment_dimensions = [1]
array_sharding_proto.tile_assignment_devices = [device_index]
token_sharding_proto = xla_client.OpSharding()
token_sharding_proto.type = xla_client.OpSharding.Type.REPLICATED
infeed_sharding_proto = xla.tuple_sharding_proto(
[array_sharding_proto] * len(non_empty_flat_results_aval) +
[token_sharding_proto])
shape = [
shape.with_major_to_minor_layout_if_absent()
for x in non_empty_flat_results_aval
for shape in xla.aval_to_xla_shapes(x)
]
build_infeed = functools.partial(xops.InfeedWithToken,
after_outfeed_itoken,
xla_client.Shape.tuple_shape(shape))
outs_and_token = _with_sharding_proto(comp, infeed_sharding_proto,
build_infeed)
outs = xops.GetTupleElement(outs_and_token, 0)
next_itoken = xops.GetTupleElement(outs_and_token, 1)
non_empty_results = [
xops.GetTupleElement(outs, i)
for i in range(len(non_empty_flat_results_aval))
]
generated_infeed = True
results = [
empty_results.pop(0)
if _aval_is_empty(result_aval) else non_empty_results.pop(0)
for result_aval in flat_results_aval
]
else:
results = empty_results
next_itoken = current_itoken
else: # !use_outfeed : CustomCall implementation
if device_index != 0:
raise ValueError("The device_index feature works only when using outfeed.")
# TODO(necula): this is a weak attempt to get the device. This works
# inside pmap, but does not work when we just execute on a single device,
# because in such executions we always get replica_id == 0.
replica_id = xla_client.ops.ReplicaId(comp)
callback_operands = (current_token, replica_id) + args_to_outfeed
if identity:
callback_flat_results_aval = (core.abstract_token,)
else:
callback_flat_results_aval = (core.abstract_token,) + flat_results_aval
def wrapped_callback(*args):
token, replica_id, *arrays = args
result_arrays = _outside_call_run_callback(
arrays,
xb.local_devices()[replica_id],
send_infeed=False,
# The same parameters as outside_call_p
_raise_if_using_outfeed_with_pjrt_c_api(xb.get_backend(ctx.platform))
callback_id = _register_callback(
functools.partial(
_outside_call_run_callback,
send_infeed=send_infeed,
identity=identity,
flat_results_aval=flat_results_aval,
**params)
if identity:
# For identity, we do not pass any results back to the device
result_arrays = ()
return (token,) + result_arrays
result_shapes = [
xla.aval_to_xla_shapes(res_aval)[0]
for res_aval in callback_flat_results_aval
]
backend = ctx.module_context.backend
token_and_results_op, keep_alive = backend.emit_python_callback(
wrapped_callback,
comp,
callback_operands,
result_shapes,
operand_layouts=None,
has_side_effects=True)
_callback_handler_data.keep_alives.append(keep_alive)
next_token, *results = (xops.GetTupleElement(token_and_results_op, i)
for i in range(len(callback_flat_results_aval)))
# We must put the two tokens at the end
if identity:
results = list(args_to_outfeed)
**params))
next_token = _callback_handler_data.receiver.add_outfeed(
comp, current_token, callback_id, args_to_outfeed, device_index)
if identity:
results = list(args_to_outfeed)
next_itoken = current_itoken
else:
empty_results = [
xops.ConstantLiteral(comp, np.zeros(aval.shape, aval.dtype))
for aval in flat_results_aval
if _aval_is_empty(aval)
]
if non_empty_flat_results_aval:
assert need_callback_results_on_device
after_outfeed_itoken = xops.AfterAll(comp, [current_itoken, next_token])
# We shard the infeed as AssignedDevice(device_index). This must match the
# outfeed (from outfeed_receiver.cc). Since `lax.infeed` does not support
# this kind of sharding, we use a custom translation for infeed.
array_sharding_proto = xla_client.OpSharding()
array_sharding_proto.type = xla_client.OpSharding.Type.MAXIMAL
array_sharding_proto.tile_assignment_dimensions = [1]
array_sharding_proto.tile_assignment_devices = [device_index]
token_sharding_proto = xla_client.OpSharding()
token_sharding_proto.type = xla_client.OpSharding.Type.REPLICATED
infeed_sharding_proto = xla.tuple_sharding_proto(
[array_sharding_proto] * len(non_empty_flat_results_aval) +
[token_sharding_proto])
shape = [
shape.with_major_to_minor_layout_if_absent()
for x in non_empty_flat_results_aval
for shape in xla.aval_to_xla_shapes(x)
]
build_infeed = functools.partial(xops.InfeedWithToken,
after_outfeed_itoken,
xla_client.Shape.tuple_shape(shape))
outs_and_token = _with_sharding_proto(comp, infeed_sharding_proto,
build_infeed)
outs = xops.GetTupleElement(outs_and_token, 0)
next_itoken = xops.GetTupleElement(outs_and_token, 1)
non_empty_results = [
xops.GetTupleElement(outs, i)
for i in range(len(non_empty_flat_results_aval))
]
generated_infeed = True
results = [
empty_results.pop(0)
if _aval_is_empty(result_aval) else non_empty_results.pop(0)
for result_aval in flat_results_aval
]
else:
results = empty_results
next_itoken = current_itoken
assert generated_infeed == send_infeed, (
f"generated_infeed ({generated_infeed}) != send_infeed ({send_infeed})")

View File

@ -961,7 +961,7 @@ jax_test(
)
jax_test(
name = "host_callback_test",
name = "host_callback_outfeed_test",
srcs = ["host_callback_test.py"],
args = ["--jax_host_callback_outfeed=true"],
shard_count = {
@ -975,13 +975,9 @@ jax_test(
)
jax_test(
name = "host_callback_custom_call_test",
name = "host_callback_test",
srcs = ["host_callback_test.py"],
args = ["--jax_host_callback_outfeed=false"],
disable_backends = [
"gpu",
"tpu", # On TPU we always use outfeed
],
main = "host_callback_test.py",
shard_count = {
"gpu": 5,