mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
[callback] Enable device_index support in terms of callback sharding support.
This is part of deprecating host_callback and moving to io_callback. PiperOrigin-RevId: 561856023
This commit is contained in:
parent
e0a6230214
commit
efaea8ed32
@ -503,7 +503,6 @@ import math
|
||||
import threading
|
||||
import traceback
|
||||
from typing import Any, Callable, Optional, cast
|
||||
import warnings
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import core
|
||||
@ -1117,8 +1116,11 @@ def _outside_call_lowering(ctx: mlir.LoweringRuleContext,
|
||||
device_index=device_index,
|
||||
**params)
|
||||
else:
|
||||
if device_index != 0:
|
||||
raise ValueError("The device_index feature works only when using outfeed.")
|
||||
# TODO(necula): It seems that on CPU, with custom call, the device_index
|
||||
# does not work, and the callback is always run on device_index=0
|
||||
if (device_index != 0 and ctx.module_context.platform == "cpu"):
|
||||
raise ValueError(
|
||||
"The device_index feature on CPU works only when using outfeed.")
|
||||
# We expect the current tokens at the end, inserted by _rewrite_jaxpr.
|
||||
assert has_token
|
||||
current_token = args[-2]
|
||||
|
@ -1543,7 +1543,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
if (device_index != 0 and
|
||||
not FLAGS.jax_host_callback_outfeed and
|
||||
jtu.device_under_test() == "cpu"):
|
||||
raise SkipTest("device_index works only with outfeed")
|
||||
# See comment in host_callback.py.
|
||||
raise SkipTest("device_index works only with outfeed on CPU")
|
||||
|
||||
devices = np.array(local_devices())
|
||||
nr_devices = len(devices)
|
||||
|
Loading…
x
Reference in New Issue
Block a user