[callback] Enable device_index support in terms of callback sharding support.

This is part of deprecating host_callback and moving to io_callback.

PiperOrigin-RevId: 561856023
This commit is contained in:
George Necula 2023-08-31 22:30:59 -07:00 committed by jax authors
parent e0a6230214
commit efaea8ed32
2 changed files with 7 additions and 4 deletions

View File

@ -503,7 +503,6 @@ import math
import threading
import traceback
from typing import Any, Callable, Optional, cast
import warnings
from jax._src import api
from jax._src import core
@ -1117,8 +1116,11 @@ def _outside_call_lowering(ctx: mlir.LoweringRuleContext,
device_index=device_index,
**params)
else:
if device_index != 0:
raise ValueError("The device_index feature works only when using outfeed.")
# TODO(necula): It seems that on CPU, with custom call, the device_index
# does not work, and the callback is always run on device_index=0
if (device_index != 0 and ctx.module_context.platform == "cpu"):
raise ValueError(
"The device_index feature on CPU works only when using outfeed.")
# We expect the current tokens at the end, inserted by _rewrite_jaxpr.
assert has_token
current_token = args[-2]

View File

@ -1543,7 +1543,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
if (device_index != 0 and
not FLAGS.jax_host_callback_outfeed and
jtu.device_under_test() == "cpu"):
raise SkipTest("device_index works only with outfeed")
# See comment in host_callback.py.
raise SkipTest("device_index works only with outfeed on CPU")
devices = np.array(local_devices())
nr_devices = len(devices)