mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[host_callback] Delete unused code paths.
This is part of deprecating host_callback and moving to io_callback. PiperOrigin-RevId: 561851494
This commit is contained in:
parent
70b58bbd30
commit
e0a6230214
@ -1002,6 +1002,8 @@ def _outside_call_translation_rule(ctx,
|
||||
**params):
|
||||
# We expect the current tokens at the end, inserted by _rewrite_jaxpr.
|
||||
assert has_token
|
||||
use_outfeed = _use_outfeed(ctx.platform)
|
||||
assert use_outfeed, 'Should be using MLIR path for `CustomCall` lowering'
|
||||
current_token = args_op[-2]
|
||||
current_itoken = args_op[-1]
|
||||
comp = ctx.builder
|
||||
@ -1015,13 +1017,9 @@ def _outside_call_translation_rule(ctx,
|
||||
flat_results_aval))
|
||||
need_callback_results_on_device = (not identity and
|
||||
len(non_empty_flat_results_aval) > 0)
|
||||
use_outfeed = _use_outfeed(ctx.platform)
|
||||
# TODO(sharadmv): Delete non-outfeed path when jaxlib minimum version is
|
||||
# bumped past 0.3.8.
|
||||
assert use_outfeed, 'Should be using MLIR path for `CustomCall` lowering'
|
||||
send_infeed = use_outfeed and need_callback_results_on_device
|
||||
generated_infeed = False # Keep track if we emitted an infeed op
|
||||
if use_outfeed:
|
||||
|
||||
_raise_if_using_outfeed_with_pjrt_c_api(xb.get_backend(ctx.platform))
|
||||
callback_id = _register_callback(
|
||||
functools.partial(
|
||||
@ -1085,55 +1083,6 @@ def _outside_call_translation_rule(ctx,
|
||||
results = empty_results
|
||||
next_itoken = current_itoken
|
||||
|
||||
else: # !use_outfeed : CustomCall implementation
|
||||
if device_index != 0:
|
||||
raise ValueError("The device_index feature works only when using outfeed.")
|
||||
|
||||
# TODO(necula): this is a weak attempt to get the device. This works
|
||||
# inside pmap, but does not work when we just execute on a single device,
|
||||
# because in such executions we always get replica_id == 0.
|
||||
replica_id = xla_client.ops.ReplicaId(comp)
|
||||
callback_operands = (current_token, replica_id) + args_to_outfeed
|
||||
if identity:
|
||||
callback_flat_results_aval = (core.abstract_token,)
|
||||
else:
|
||||
callback_flat_results_aval = (core.abstract_token,) + flat_results_aval
|
||||
|
||||
def wrapped_callback(*args):
|
||||
token, replica_id, *arrays = args
|
||||
result_arrays = _outside_call_run_callback(
|
||||
arrays,
|
||||
xb.local_devices()[replica_id],
|
||||
send_infeed=False,
|
||||
# The same parameters as outside_call_p
|
||||
identity=identity,
|
||||
flat_results_aval=flat_results_aval,
|
||||
**params)
|
||||
if identity:
|
||||
# For identity, we do not pass the any results back to the device
|
||||
result_arrays = ()
|
||||
return (token,) + result_arrays
|
||||
|
||||
result_shapes = [
|
||||
xla.aval_to_xla_shapes(res_aval)[0]
|
||||
for res_aval in callback_flat_results_aval
|
||||
]
|
||||
backend = ctx.module_context.backend
|
||||
token_and_results_op, keep_alive = backend.emit_python_callback(
|
||||
wrapped_callback,
|
||||
comp,
|
||||
callback_operands,
|
||||
result_shapes,
|
||||
operand_layouts=None,
|
||||
has_side_effects=True)
|
||||
_callback_handler_data.keep_alives.append(keep_alive)
|
||||
next_token, *results = (xops.GetTupleElement(token_and_results_op, i)
|
||||
for i in range(len(callback_flat_results_aval)))
|
||||
# We must put the two tokens at the end
|
||||
if identity:
|
||||
results = list(args_to_outfeed)
|
||||
next_itoken = current_itoken
|
||||
|
||||
assert generated_infeed == send_infeed, (
|
||||
f"generated_infeed ({generated_infeed}) != send_infeed ({send_infeed})")
|
||||
assert identity or len(results) == len(flat_results_aval), (
|
||||
|
@ -961,7 +961,7 @@ jax_test(
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "host_callback_test",
|
||||
name = "host_callback_outfeed_test",
|
||||
srcs = ["host_callback_test.py"],
|
||||
args = ["--jax_host_callback_outfeed=true"],
|
||||
shard_count = {
|
||||
@ -975,13 +975,9 @@ jax_test(
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "host_callback_custom_call_test",
|
||||
name = "host_callback_test",
|
||||
srcs = ["host_callback_test.py"],
|
||||
args = ["--jax_host_callback_outfeed=false"],
|
||||
disable_backends = [
|
||||
"gpu",
|
||||
"tpu", # On TPU we always use outfeed
|
||||
],
|
||||
main = "host_callback_test.py",
|
||||
shard_count = {
|
||||
"gpu": 5,
|
||||
|
Loading…
x
Reference in New Issue
Block a user