mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Make infeed_test and host_callback_test independent. (#3676)
* Make infeed_test and host_callback_test independent. * the infeed_test will stop the outfeed receiver * Remove the use of --dist=loadfile. * Prevent logging on exit
This commit is contained in:
parent
d2ebb6eb19
commit
bf97e47929
@ -876,6 +876,8 @@ def _initialize_outfeed_receiver(
|
||||
max_callback_queue_size_bytes)
|
||||
|
||||
def exit_handler():
|
||||
# Prevent logging usage during compilation, gives errors under pytest
|
||||
xla._on_exit = True
|
||||
logging.vlog(2, "Barrier wait atexit")
|
||||
barrier_wait()
|
||||
|
||||
|
@ -66,6 +66,9 @@ flags.DEFINE_bool('jax_log_compiles',
|
||||
bool_env('JAX_LOG_COMPILES', False),
|
||||
'Print a message each time a `jit` computation is compiled.')
|
||||
|
||||
# This flag is set on exit; no logging should be attempted
|
||||
_on_exit = False
|
||||
|
||||
def identity(x): return x
|
||||
|
||||
_scalar_types = dtypes.python_scalar_dtypes.keys()
|
||||
@ -613,8 +616,9 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
|
||||
if not jaxpr.eqns:
|
||||
return partial(_execute_trivial, jaxpr, device, consts, result_handlers)
|
||||
|
||||
log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG
|
||||
logging.log(log_priority, "Compiling %s for args %s.", fun.__name__, abstract_args)
|
||||
if not _on_exit:
|
||||
log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG
|
||||
logging.log(log_priority, "Compiling %s for args %s.", fun.__name__, abstract_args)
|
||||
|
||||
if nreps > 1:
|
||||
warn(f"The jitted function {fun.__name__} includes a pmap. Using "
|
||||
|
@ -10,8 +10,5 @@ filterwarnings =
|
||||
ignore:can't resolve package from __spec__ or __package__:ImportWarning
|
||||
ignore:Using or importing the ABCs.*:DeprecationWarning
|
||||
doctest_optionflags = NUMBER NORMALIZE_WHITESPACE
|
||||
addopts = --doctest-glob="*.rst" --dist=loadfile
|
||||
# --dist=loadfile ensure that all the tests in one file are sent to the same runner. This is useful
|
||||
# for host_callback_test which start and then stop on teardown the C++ outfeed receiver
|
||||
# runtime. If we do not stop the receiver, other tests that use outfeed are going to fail.
|
||||
addopts = --doctest-glob="*.rst"
|
||||
|
||||
|
@ -124,10 +124,6 @@ class HostCallbackTest(jtu.JaxTestCase):
|
||||
xla_bridge.get_backend.cache_clear()
|
||||
hcb.barrier_wait()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
hcb.stop_outfeed_receiver()
|
||||
|
||||
def helper_set_devices(self, nr_devices):
|
||||
flags_str = os.getenv("XLA_FLAGS", "")
|
||||
os.environ["XLA_FLAGS"] = (
|
||||
@ -962,6 +958,7 @@ what: x times i
|
||||
Check that we get the proper error from the runtime."""
|
||||
comp = xla_bridge.make_computation_builder(self._testMethodName)
|
||||
token = hcb.xops.CreateToken(comp)
|
||||
hcb._initialize_outfeed_receiver() # Needed if this is the sole test
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
"Consumer ID cannot be a reserved value: 0"):
|
||||
hcb._outfeed_receiver.receiver.add_outfeed(
|
||||
@ -972,6 +969,7 @@ what: x times i
|
||||
"""Try to register different shapes for the same consumer ID."""
|
||||
comp = xla_bridge.make_computation_builder(self._testMethodName)
|
||||
token = hcb.xops.CreateToken(comp)
|
||||
hcb._initialize_outfeed_receiver() # Needed if this is the sole test
|
||||
hcb._outfeed_receiver.receiver.add_outfeed(
|
||||
comp, token, 123,
|
||||
[xla_bridge.constant(comp, np.zeros((2, 3), dtype=np.float32))])
|
||||
|
Loading…
x
Reference in New Issue
Block a user