Mirror of https://github.com/ROCm/jax.git, synced 2025-04-18 12:56:07 +00:00
[host_callback] Delete unused code paths.
This is part of deprecating host_callback and moving to io_callback.

PiperOrigin-RevId: 561851494
parent 70b58bbd30
commit e0a6230214
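For context, io_callback is the supported replacement that the commit message refers to. Below is a minimal sketch of the migration target (not part of this commit; the host function `host_fn` and the example shapes are illustrative assumptions):

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.experimental import io_callback

    def host_fn(x):
      # Runs on the host, outside the compiled computation.
      print("host received:", x)
      return np.asarray(x)

    @jax.jit
    def f(x):
      # The ShapeDtypeStruct declares the shape/dtype of the host result,
      # much as flat_results_aval does in the lowering code below.
      y = io_callback(host_fn, jax.ShapeDtypeStruct(x.shape, x.dtype), x)
      return y + 1.0

    print(f(jnp.arange(3.0)))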
jax/experimental/host_callback.py

@@ -1002,6 +1002,8 @@ def _outside_call_translation_rule(ctx,
                                    **params):
   # We expect the current tokens at the end, inserted by _rewrite_jaxpr.
   assert has_token
+  use_outfeed = _use_outfeed(ctx.platform)
+  assert use_outfeed, 'Should be using MLIR path for `CustomCall` lowering'
   current_token = args_op[-2]
   current_itoken = args_op[-1]
   comp = ctx.builder
@@ -1015,124 +1017,71 @@ def _outside_call_translation_rule(ctx,
                                             flat_results_aval))
   need_callback_results_on_device = (not identity and
                                      len(non_empty_flat_results_aval) > 0)
-  use_outfeed = _use_outfeed(ctx.platform)
-  # TODO(sharadmv): Delete non-outfeed path when jaxlib minimum version is
-  # bumped past 0.3.8.
-  assert use_outfeed, 'Should be using MLIR path for `CustomCall` lowering'
   send_infeed = use_outfeed and need_callback_results_on_device
   generated_infeed = False  # Keep track if we emitted an infeed op
-  if use_outfeed:
-    _raise_if_using_outfeed_with_pjrt_c_api(xb.get_backend(ctx.platform))
-    callback_id = _register_callback(
-        functools.partial(
-            _outside_call_run_callback,
-            send_infeed=send_infeed,
-            identity=identity,
-            flat_results_aval=flat_results_aval,
-            **params))
-    next_token = _callback_handler_data.receiver.add_outfeed(
-        comp, current_token, callback_id, args_to_outfeed, device_index)
-    if identity:
-      results = list(args_to_outfeed)
-      next_itoken = current_itoken
-    else:
-      empty_results = [
-          xops.ConstantLiteral(comp, np.zeros(aval.shape, aval.dtype))
-          for aval in flat_results_aval
-          if _aval_is_empty(aval)
-      ]
-      if non_empty_flat_results_aval:
-        assert need_callback_results_on_device
-        after_outfeed_itoken = xops.AfterAll(comp, [current_itoken, next_token])
-        # We shard the infeed as AssignedDevice(device_index). This must match the
-        # outfeed (from outfeed_receiver.cc). Since `lax.infeed` does not support
-        # this kind of sharding, we use a custom translation for infeed.
-        array_sharding_proto = xla_client.OpSharding()
-        array_sharding_proto.type = xla_client.OpSharding.Type.MAXIMAL
-        array_sharding_proto.tile_assignment_dimensions = [1]
-        array_sharding_proto.tile_assignment_devices = [device_index]
-
-        token_sharding_proto = xla_client.OpSharding()
-        token_sharding_proto.type = xla_client.OpSharding.Type.REPLICATED
-        infeed_sharding_proto = xla.tuple_sharding_proto(
-            [array_sharding_proto] * len(non_empty_flat_results_aval) +
-            [token_sharding_proto])
-
-        shape = [
-            shape.with_major_to_minor_layout_if_absent()
-            for x in non_empty_flat_results_aval
-            for shape in xla.aval_to_xla_shapes(x)
-        ]
-
-        build_infeed = functools.partial(xops.InfeedWithToken,
-                                         after_outfeed_itoken,
-                                         xla_client.Shape.tuple_shape(shape))
-        outs_and_token = _with_sharding_proto(comp, infeed_sharding_proto,
-                                              build_infeed)
-        outs = xops.GetTupleElement(outs_and_token, 0)
-        next_itoken = xops.GetTupleElement(outs_and_token, 1)
-        non_empty_results = [
-            xops.GetTupleElement(outs, i)
-            for i in range(len(non_empty_flat_results_aval))
-        ]
-        generated_infeed = True
-        results = [
-            empty_results.pop(0)
-            if _aval_is_empty(result_aval) else non_empty_results.pop(0)
-            for result_aval in flat_results_aval
-        ]
-      else:
-        results = empty_results
-        next_itoken = current_itoken
-
-  else:  # !use_outfeed : CustomCall implementation
-    if device_index != 0:
-      raise ValueError("The device_index feature works only when using outfeed.")
-
-    # TODO(necula): this is a weak attempt to get the device. This works
-    # inside pmap, but does not work when we just execute on a single device,
-    # because in such executions we always get replica_id == 0.
-    replica_id = xla_client.ops.ReplicaId(comp)
-    callback_operands = (current_token, replica_id) + args_to_outfeed
-    if identity:
-      callback_flat_results_aval = (core.abstract_token,)
-    else:
-      callback_flat_results_aval = (core.abstract_token,) + flat_results_aval
-
-    def wrapped_callback(*args):
-      token, replica_id, *arrays = args
-      result_arrays = _outside_call_run_callback(
-          arrays,
-          xb.local_devices()[replica_id],
-          send_infeed=False,
-          # The same parameters as outside_call_p
-          identity=identity,
-          flat_results_aval=flat_results_aval,
-          **params)
-      if identity:
-        # For identity, we do not pass the any results back to the device
-        result_arrays = ()
-      return (token,) + result_arrays
-
-    result_shapes = [
-        xla.aval_to_xla_shapes(res_aval)[0]
-        for res_aval in callback_flat_results_aval
-    ]
-    backend = ctx.module_context.backend
-    token_and_results_op, keep_alive = backend.emit_python_callback(
-        wrapped_callback,
-        comp,
-        callback_operands,
-        result_shapes,
-        operand_layouts=None,
-        has_side_effects=True)
-    _callback_handler_data.keep_alives.append(keep_alive)
-    next_token, *results = (xops.GetTupleElement(token_and_results_op, i)
-                            for i in range(len(callback_flat_results_aval)))
-    # We must put the two tokens at the end
-    if identity:
-      results = list(args_to_outfeed)
-    next_itoken = current_itoken
+  _raise_if_using_outfeed_with_pjrt_c_api(xb.get_backend(ctx.platform))
+  callback_id = _register_callback(
+      functools.partial(
+          _outside_call_run_callback,
+          send_infeed=send_infeed,
+          identity=identity,
+          flat_results_aval=flat_results_aval,
+          **params))
+  next_token = _callback_handler_data.receiver.add_outfeed(
+      comp, current_token, callback_id, args_to_outfeed, device_index)
+  if identity:
+    results = list(args_to_outfeed)
+    next_itoken = current_itoken
+  else:
+    empty_results = [
+        xops.ConstantLiteral(comp, np.zeros(aval.shape, aval.dtype))
+        for aval in flat_results_aval
+        if _aval_is_empty(aval)
+    ]
+    if non_empty_flat_results_aval:
+      assert need_callback_results_on_device
+      after_outfeed_itoken = xops.AfterAll(comp, [current_itoken, next_token])
+      # We shard the infeed as AssignedDevice(device_index). This must match the
+      # outfeed (from outfeed_receiver.cc). Since `lax.infeed` does not support
+      # this kind of sharding, we use a custom translation for infeed.
+      array_sharding_proto = xla_client.OpSharding()
+      array_sharding_proto.type = xla_client.OpSharding.Type.MAXIMAL
+      array_sharding_proto.tile_assignment_dimensions = [1]
+      array_sharding_proto.tile_assignment_devices = [device_index]
+
+      token_sharding_proto = xla_client.OpSharding()
+      token_sharding_proto.type = xla_client.OpSharding.Type.REPLICATED
+      infeed_sharding_proto = xla.tuple_sharding_proto(
+          [array_sharding_proto] * len(non_empty_flat_results_aval) +
+          [token_sharding_proto])
+
+      shape = [
+          shape.with_major_to_minor_layout_if_absent()
+          for x in non_empty_flat_results_aval
+          for shape in xla.aval_to_xla_shapes(x)
+      ]
+
+      build_infeed = functools.partial(xops.InfeedWithToken,
+                                       after_outfeed_itoken,
+                                       xla_client.Shape.tuple_shape(shape))
+      outs_and_token = _with_sharding_proto(comp, infeed_sharding_proto,
+                                            build_infeed)
+      outs = xops.GetTupleElement(outs_and_token, 0)
+      next_itoken = xops.GetTupleElement(outs_and_token, 1)
+      non_empty_results = [
+          xops.GetTupleElement(outs, i)
+          for i in range(len(non_empty_flat_results_aval))
+      ]
+      generated_infeed = True
+      results = [
+          empty_results.pop(0)
+          if _aval_is_empty(result_aval) else non_empty_results.pop(0)
+          for result_aval in flat_results_aval
+      ]
+    else:
+      results = empty_results
+      next_itoken = current_itoken
 
   assert generated_infeed == send_infeed, (
       f"generated_infeed ({generated_infeed}) != send_infeed ({send_infeed})")
tests/BUILD

@@ -961,7 +961,7 @@ jax_test(
 )
 
 jax_test(
-    name = "host_callback_test",
+    name = "host_callback_outfeed_test",
     srcs = ["host_callback_test.py"],
     args = ["--jax_host_callback_outfeed=true"],
     shard_count = {
@@ -975,13 +975,9 @@ jax_test(
 )
 
 jax_test(
-    name = "host_callback_custom_call_test",
+    name = "host_callback_test",
     srcs = ["host_callback_test.py"],
-    args = ["--jax_host_callback_outfeed=false"],
-    disable_backends = [
-        "gpu",
-        "tpu",  # On TPU we always use outfeed
-    ],
-    main = "host_callback_test.py",
     shard_count = {
         "gpu": 5,