Rollback of HCB GPU custom call due to internal failures

PiperOrigin-RevId: 460079787
This commit is contained in:
Sharad Vikram 2022-07-10 13:04:44 -07:00 committed by jax authors
parent 5910cdc861
commit b666f665ec
3 changed files with 1 additions and 6 deletions

View File

@ -537,8 +537,6 @@ def _inline_host_callback() -> bool:
def _use_outfeed(platform: str) -> bool:
if jaxlib.version >= (0, 3, 15):
return platform == "tpu" or FLAGS.jax_host_callback_outfeed
return (platform in ("tpu", "gpu", "cuda", "rocm") or FLAGS.jax_host_callback_outfeed)
xops = xla_client._xla.ops
@ -1200,8 +1198,6 @@ def _outside_call_lowering(
return results + [next_token, next_itoken]
mlir.register_lowering(outside_call_p, _outside_call_lowering, platform="cpu")
if jaxlib.version >= (0, 3, 15):
mlir.register_lowering(outside_call_p, _outside_call_lowering, platform="gpu")
def _outside_call_run_callback(
arrays, device, *,

View File

@ -783,6 +783,7 @@ jax_test(
srcs = ["host_callback_test.py"],
args = ["--jax_host_callback_outfeed=false"],
disable_backends = [
"gpu",
"tpu", # On TPU we always use outfeed
],
main = "host_callback_test.py",

View File

@ -2456,8 +2456,6 @@ class HostCallbackCallTest(jtu.JaxTestCase):
expected_exc_txt: str):
"""Calls thunk() and checks for expected exceptions.
"""
if not FLAGS.jax_host_callback_outfeed:
raise SkipTest("TODO: implement error handling for customcall")
if jtu.device_under_test() == "cpu":
# On CPU the runtime crashes, and the tests are all aborted
raise SkipTest("TODO: CPU runtime crashes on unexpected infeed")