mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Rollback of HCB GPU custom call due to internal failures
PiperOrigin-RevId: 460079787
This commit is contained in:
parent
5910cdc861
commit
b666f665ec
@ -537,8 +537,6 @@ def _inline_host_callback() -> bool:
|
||||
|
||||
|
||||
def _use_outfeed(platform: str) -> bool:
|
||||
if jaxlib.version >= (0, 3, 15):
|
||||
return platform == "tpu" or FLAGS.jax_host_callback_outfeed
|
||||
return (platform in ("tpu", "gpu", "cuda", "rocm") or FLAGS.jax_host_callback_outfeed)
|
||||
|
||||
xops = xla_client._xla.ops
|
||||
@ -1200,8 +1198,6 @@ def _outside_call_lowering(
|
||||
return results + [next_token, next_itoken]
|
||||
|
||||
mlir.register_lowering(outside_call_p, _outside_call_lowering, platform="cpu")
|
||||
if jaxlib.version >= (0, 3, 15):
|
||||
mlir.register_lowering(outside_call_p, _outside_call_lowering, platform="gpu")
|
||||
|
||||
def _outside_call_run_callback(
|
||||
arrays, device, *,
|
||||
|
@ -783,6 +783,7 @@ jax_test(
|
||||
srcs = ["host_callback_test.py"],
|
||||
args = ["--jax_host_callback_outfeed=false"],
|
||||
disable_backends = [
|
||||
"gpu",
|
||||
"tpu", # On TPU we always use outfeed
|
||||
],
|
||||
main = "host_callback_test.py",
|
||||
|
@ -2456,8 +2456,6 @@ class HostCallbackCallTest(jtu.JaxTestCase):
|
||||
expected_exc_txt: str):
|
||||
"""Calls thunk() and checks for expected exceptions.
|
||||
"""
|
||||
if not FLAGS.jax_host_callback_outfeed:
|
||||
raise SkipTest("TODO: implement error handling for customcall")
|
||||
if jtu.device_under_test() == "cpu":
|
||||
# On CPU the runtime crashes, and the tests are all aborted
|
||||
raise SkipTest("TODO: CPU runtime crashes on unexpected infeed")
|
||||
|
Loading…
x
Reference in New Issue
Block a user