At the moment, xmap SPMD lowering only enforces sharding constraints for
computation inputs and outputs, while leaving sharding propagation in the
body entirely up to the XLA SPMD partitioner. This patch adds a new flag
`experimental_xmap_enforce_inferred_sharding` that inserts additional
sharding constraints between consecutive JAX primitives in the xmapped function.
Assuming that the SPMD partitioner never overrides user-defined constraints,
this should restrict it sufficiently to generate a computation that is
partitioned exactly as implied by the evolution of intermediate named shapes.
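For illustration, a minimal sketch of how the flag would be used together with xmap's
SPMD lowering. The registration of the flag as a jax.config option is an assumption,
modeled on the existing experimental_xmap_spmd_lowering flag; it is not spelled out here.

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.experimental.maps import xmap, mesh

    # Assumption: the new flag is exposed as an ordinary jax.config option,
    # like the existing experimental_xmap_spmd_lowering flag.
    jax.config.update("experimental_xmap_spmd_lowering", True)
    jax.config.update("experimental_xmap_enforce_inferred_sharding", True)

    def f(x, y):
      # The named shapes of the intermediates determine the shardings that the
      # inserted constraints pin down between primitives.
      z = jnp.sin(x) + y
      return z * 2.0

    # Assumes a multi-device backend; the mapped axis size must be divisible by
    # the number of devices in the mesh.
    devices = np.array(jax.devices())
    with mesh(devices, ("d",)):
      fx = xmap(f,
                in_axes=(["i", ...], ["i", ...]),
                out_axes=["i", ...],
                axis_resources={"i": "d"})
      print(fx(jnp.ones((8, 4)), jnp.ones((8, 4))).shape)  # (8, 4)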
PiperOrigin-RevId: 385562158
Previously there was one thread per device for receiving the outfeed from
devices, but there was a single global thread calling into the Python
callbacks. This meant that if one of the callbacks was slow, it blocked the
processing of all other callbacks.
One situation where this created difficulties was when one wanted to break a host_callback into two operations: a quick one to enqueue work on a threadpool,
and a subsequent slow one to wait for and retrieve the result. A slow callback would then block all other callbacks, possibly including some quick ones, thus missing the opportunity to start the slow work early.
With this change there is a separate queue of outfeeds for each device and a
separate thread per device to call into Python. This allows for concurrency
between callbacks from different devices, although the callbacks from one
device are still sequential. If the programmer wants more concurrency, they can use a threadpool. Providing more concurrency by default is tricky, because it could result in the Python callbacks for one device being observed out of order.
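A sketch of the two-callback pattern this enables, using jax.experimental.host_callback.call.
The threadpool, the ticket table, and the enqueue/retrieve helpers below are illustrative,
not part of JAX:

    from concurrent.futures import ThreadPoolExecutor

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.experimental import host_callback as hcb

    pool = ThreadPoolExecutor(max_workers=4)
    futures = {}  # ticket -> Future; purely illustrative bookkeeping

    def slow_host_fn(x):
      return np.asarray(x) * 2.0  # stand-in for expensive host work

    def enqueue(x):
      # Quick callback: submit the work and hand back a ticket immediately.
      ticket = len(futures)
      futures[ticket] = pool.submit(slow_host_fn, x)
      return np.int32(ticket)

    def retrieve(ticket):
      # Slow callback: block until the result for this ticket is ready.
      return futures.pop(int(ticket)).result()

    def f(x):
      ticket = hcb.call(enqueue, x,
                        result_shape=jax.ShapeDtypeStruct((), np.int32))
      # ... other device work can overlap with the host-side computation ...
      return hcb.call(retrieve, ticket,
                      result_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))

    print(jax.jit(f)(jnp.arange(4.0)))  # [0. 2. 4. 6.]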
PiperOrigin-RevId: 385493070
When running across 8xV100 GPUs we observed the following error:
libc++abi: terminating with uncaught exception of type std::runtime_error: third_party/py/jax/jaxlib/cusolver.cc:171: operation cusolverDnSpotrf(handle.get(), d.uplo, d.n, a, d.n, static_cast<float*>(workspace), d.lwork, info) failed: cuSolver execution failed
I cannot find documentation to this effect, but I believe it is unsafe to share cuSolver handles across streams, since keeping the handle pool stream-local does solve the issue.
A small bug that surfaces in error messages: if a function has been wrapped with `functools.wraps`, some error messages (e.g. UnexpectedTracerError) will point to the name of the wrapped function but to the filename of the wrapping function (because `functools.wraps` does not update the code object of the function, but it _does_ update the name of the function).
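A small stdlib-only reproduction of the mismatch (independent of JAX):

    import functools

    def make_wrapper(fn):
      @functools.wraps(fn)           # copies fn.__name__, __doc__, __module__, ...
      def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
      return wrapper                 # ...but wrapper.__code__ is left untouched

    def user_fn(x):
      return x + 1

    wrapped = make_wrapper(user_fn)
    print(wrapped.__name__)              # 'user_fn'  (the wrapped function's name)
    print(wrapped.__code__.co_name)      # 'wrapper'
    print(wrapped.__code__.co_filename)  # the file where `wrapper` is defined; in JAX
                                         # the wrapper lives in a different file than the
                                         # user's function, hence the confusing messages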
A user reported that with their Quadro M4000 GPU (Driver: 460.56) tridiagonal_solve was throwing an "unsupported operation" error. I improved the logging (also included in this patch) and tracked it down to:
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: third_party/py/jax/jaxlib/cusparse.cc:902: CUDA operation cudaMallocAsync(&buffer, bufferSize, stream) failed: operation not supported
I had some trouble figuring out when async malloc is supported (it appears to fail on cards with compute capability < 6), but found an alternative approach: compute the buffer size ahead of time and ask XLA to allocate it. This is preferable in any case (although it requires passing null pointers into cusparseSgtsv2_bufferSizeExt, which works today but might change in future cuSPARSE releases).
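For context, an illustrative example of the kind of call that reaches the cuSPARSE gtsv2
path on a GPU backend. The user's actual code is not available; shapes and values here
are arbitrary.

    import jax.numpy as jnp
    from jax import jit
    from jax.lax.linalg import tridiagonal_solve

    # A simple constant-coefficient tridiagonal system A x = b.
    m, n = 128, 4
    dl = jnp.concatenate([jnp.zeros(1), jnp.full(m - 1, -1.0)])  # dl[0] must be 0
    d = jnp.full(m, 2.0)
    du = jnp.concatenate([jnp.full(m - 1, -1.0), jnp.zeros(1)])  # du[m-1] must be 0
    b = jnp.ones((m, n))

    x = jit(tridiagonal_solve)(dl, d, du, b)
    print(x.shape)  # (128, 4)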