Make infeed_test and host_callback_test independent. (#3676)

* Make infeed_test and host_callback_test independent.

* the infeed_test will stop the outfeed receiver
* Remove the use of --dist=loadfile.
* Prevent logging on exit
This commit is contained in:
George Necula 2020-07-07 11:03:30 +03:00 committed by GitHub
parent d2ebb6eb19
commit bf97e47929
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 11 additions and 10 deletions

View File

@ -876,6 +876,8 @@ def _initialize_outfeed_receiver(
max_callback_queue_size_bytes)
def exit_handler():
# Prevent logging usage during compilation, gives errors under pytest
xla._on_exit = True
logging.vlog(2, "Barrier wait atexit")
barrier_wait()

View File

@ -66,6 +66,9 @@ flags.DEFINE_bool('jax_log_compiles',
bool_env('JAX_LOG_COMPILES', False),
'Print a message each time a `jit` computation is compiled.')
# This flag is set on exit; no logging should be attempted
_on_exit = False
def identity(x): return x
_scalar_types = dtypes.python_scalar_dtypes.keys()
@ -613,8 +616,9 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
if not jaxpr.eqns:
return partial(_execute_trivial, jaxpr, device, consts, result_handlers)
log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG
logging.log(log_priority, "Compiling %s for args %s.", fun.__name__, abstract_args)
if not _on_exit:
log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG
logging.log(log_priority, "Compiling %s for args %s.", fun.__name__, abstract_args)
if nreps > 1:
warn(f"The jitted function {fun.__name__} includes a pmap. Using "

View File

@ -10,8 +10,5 @@ filterwarnings =
ignore:can't resolve package from __spec__ or __package__:ImportWarning
ignore:Using or importing the ABCs.*:DeprecationWarning
doctest_optionflags = NUMBER NORMALIZE_WHITESPACE
addopts = --doctest-glob="*.rst" --dist=loadfile
# --dist=loadfile ensure that all the tests in one file are sent to the same runner. This is useful
# for host_callback_test which start and then stop on teardown the C++ outfeed receiver
# runtime. If we do not stop the receiver, other tests that use outfeed are going to fail.
addopts = --doctest-glob="*.rst"

View File

@ -124,10 +124,6 @@ class HostCallbackTest(jtu.JaxTestCase):
xla_bridge.get_backend.cache_clear()
hcb.barrier_wait()
@classmethod
def tearDownClass(cls):
hcb.stop_outfeed_receiver()
def helper_set_devices(self, nr_devices):
flags_str = os.getenv("XLA_FLAGS", "")
os.environ["XLA_FLAGS"] = (
@ -962,6 +958,7 @@ what: x times i
Check that we get the proper error from the runtime."""
comp = xla_bridge.make_computation_builder(self._testMethodName)
token = hcb.xops.CreateToken(comp)
hcb._initialize_outfeed_receiver() # Needed if this is the sole test
with self.assertRaisesRegex(RuntimeError,
"Consumer ID cannot be a reserved value: 0"):
hcb._outfeed_receiver.receiver.add_outfeed(
@ -972,6 +969,7 @@ what: x times i
"""Try to register different shapes for the same consumer ID."""
comp = xla_bridge.make_computation_builder(self._testMethodName)
token = hcb.xops.CreateToken(comp)
hcb._initialize_outfeed_receiver() # Needed if this is the sole test
hcb._outfeed_receiver.receiver.add_outfeed(
comp, token, 123,
[xla_bridge.constant(comp, np.zeros((2, 3), dtype=np.float32))])