From b666f665ece2b91e7c07c65d27cf485a779638b6 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Sun, 10 Jul 2022 13:04:44 -0700 Subject: [PATCH] Rollback of HCB GPU custom call due to internal failures PiperOrigin-RevId: 460079787 --- jax/experimental/host_callback.py | 4 ---- tests/BUILD | 1 + tests/host_callback_test.py | 2 -- 3 files changed, 1 insertion(+), 6 deletions(-) diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index 16fe2b271..9e36cbeb8 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -537,8 +537,6 @@ def _inline_host_callback() -> bool: def _use_outfeed(platform: str) -> bool: - if jaxlib.version >= (0, 3, 15): - return platform == "tpu" or FLAGS.jax_host_callback_outfeed return (platform in ("tpu", "gpu", "cuda", "rocm") or FLAGS.jax_host_callback_outfeed) xops = xla_client._xla.ops @@ -1200,8 +1198,6 @@ def _outside_call_lowering( return results + [next_token, next_itoken] mlir.register_lowering(outside_call_p, _outside_call_lowering, platform="cpu") -if jaxlib.version >= (0, 3, 15): - mlir.register_lowering(outside_call_p, _outside_call_lowering, platform="gpu") def _outside_call_run_callback( arrays, device, *, diff --git a/tests/BUILD b/tests/BUILD index ccd5cfdd5..6c990df86 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -783,6 +783,7 @@ jax_test( srcs = ["host_callback_test.py"], args = ["--jax_host_callback_outfeed=false"], disable_backends = [ + "gpu", "tpu", # On TPU we always use outfeed ], main = "host_callback_test.py", diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index b57c55a9e..54b3c99ef 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -2456,8 +2456,6 @@ class HostCallbackCallTest(jtu.JaxTestCase): expected_exc_txt: str): """Calls thunk() and checks for expected exceptions. """ - if not FLAGS.jax_host_callback_outfeed: - raise SkipTest("TODO: implement error handling for customcall") if jtu.device_under_test() == "cpu": # On CPU the runtime crashes, and the tests are all aborted raise SkipTest("TODO: CPU runtime crashes on unexpected infeed")