As discussed in https://github.com/jax-ml/jax/issues/25861 and https://github.com/jax-ml/jax/issues/24255, using host callbacks within an asynchronously-dispatched CPU executable can deadlock when the body of the callback itself asynchronously dispatches JAX CPU code. My rough understanding of the problem is that the XLA intra op thread pool gets filled up with callbacks waiting for their body to execute, but there aren't enough resources to schedule the inner computations.
There's probably a better way to fix this within XLA:CPU, but the temporary fix that I've come up with is to disable asynchronous dispatch on CPU when either:
1. Executing a program that includes any host callbacks, or
2. when running within the body of a callback.
It seems like both of these conditions are needed in general because I was able to find test cases that failed with just one or the other implemented.
This PR includes just the first change, and the second will be implemented in a follow-up.
PiperOrigin-RevId: 720777713
This subsumes (and ultimately will deprecate) overriding the number of CPU devices via XLA_FLAGS.
In addition, replace the test utility jtu.set_host_platform_device_count with jtu.request_cpu_devices(...), which sets or increases the flag's value. This both removes the need for an overly complicated context stack, and prepares for removing remaining uses of setUpModule as part of work parallelizing the test suite with threads.
PiperOrigin-RevId: 713272197
Prior to this change the behavior in eager and under jax.jit was inconsistent
>>> (lambda *args: jax.debug.callback(print, *args))([42])
[42]
>>> jax.jit(lambda *args: jax.debug.callback(print, *args))([42])
[array(42, dtype=int32)]
It was also inconsistent with other callback APIs, which cast the arguments
to jax.Arrays.
Closes#20627.
PiperOrigin-RevId: 626461904
These tests work on both GPU and the current (non-stream_executor) TPU runtime, so the conditions aren't needed any more.
Tag a couple of tests as "multiaccelerator" since they appear to benefit from multiple devices.
PiperOrigin-RevId: 565367453
Instead, we skip tests that the PJRT C API doesn't support. We had
this tag for feature development so it was easy to broadly disable,
but now we don't expect to need to do that.
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.
PiperOrigin-RevId: 516223009
Limit jax._src.lib to shims around jaxlib and nothing else.
The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.
PiperOrigin-RevId: 512922397
This is a safe and trivial name replacement. It does not change any of the semantics. You can still pass in PatitionSpecs to in_shardings and out_shardings.
PiperOrigin-RevId: 510671300
DETAILS:
Due to xc.register_custom_call_partitioner is not supported on libtpu, the following two tests are skipped:
tests/pjit_test.py::PJitTest::test_custom_partitioner
tests/debugging_primitives_test.py::InspectShardingTest::test_inspect_sharding_is_called_in_pjit
DETAILS:
Run on CloudTPU v2-8 and found some tests in debugging_primitives_test
fail due to stream_executor runtime cannot support host callback.
Since host callback only support TFRT, so that skip all those types if
runtime type is stream_executor.
TESTED:
passed unit test on both TPU v2-8 and CPU.