Get rid of comparison with scipy.linalg.polar, since its outputs are significantly less accurate than QDWH. Since the polar decomposition is unique, comparing to a less accurate implementation does not add value.
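One way to test a polar decomposition without a second implementation is to check the defining properties of the factors directly. The sketch below is only an illustration in NumPy: `polar_via_svd` is a hypothetical SVD-based stand-in for the routine under test (it is not QDWH), and `polar_errors` is an assumed helper name, not part of the actual test suite.

```python
import numpy as np


def polar_via_svd(a):
    """Right polar decomposition a = u @ p from the SVD (stand-in, not QDWH)."""
    w, s, vh = np.linalg.svd(a, full_matrices=False)
    u = w @ vh                      # factor with orthonormal columns
    p = (vh.conj().T * s) @ vh      # V @ diag(s) @ V^H, Hermitian PSD factor
    return u, p


def polar_errors(a, u, p):
    """Measure how well (u, p) satisfies the defining properties."""
    n = a.shape[1]
    scale = np.linalg.norm(a)
    return {
        # a should be reproduced by the product of the factors.
        "reconstruction": np.linalg.norm(u @ p - a) / scale,
        # u should have orthonormal columns.
        "orthogonality": np.linalg.norm(u.conj().T @ u - np.eye(n)),
        # p should be Hermitian ...
        "symmetry": np.linalg.norm(p - p.conj().T) / scale,
        # ... and positive semidefinite (smallest eigenvalue >= 0).
        "min_eigenvalue": float(np.linalg.eigvalsh((p + p.conj().T) / 2).min()),
    }


rng = np.random.default_rng(0)
a = rng.standard_normal((6, 4))
u, p = polar_via_svd(a)
errors = polar_errors(a, u, p)
```

For a full-rank input these properties pin down both factors, so small residuals here say more about accuracy than agreement with another solver would.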
PiperOrigin-RevId: 642423757
XLA:CPU is migrating from compiling a monolithic LLVM function for the whole HLO module to a thin runtime with separate functions for each kernel (fusion, individual operation, library call, etc.). While the new runtime is not enabled by default, we will use an explicit opt-in on tests that are already compatible.
This tag will be removed once XLA:CPU switches to the new runtime by default.
PiperOrigin-RevId: 640022517
The FFI headers aren't properly exposed during a bazel build, so these
tests are failing. I'll re-enable them when I get a chance to get that
working properly.
This test recently started exceeding the default memory requirement when run under tsan. I'm not entirely sure why, but perhaps some change pushed it just over our CI's 12GB default limit.
PiperOrigin-RevId: 636910434
The new "typed" API that XLA provides for foreign function calls is
header-only and packaging it as part of jaxlib could simplify the open
source workflow for building custom calls.
It's not completely obvious that we need to include this, because jaxlib
isn't strictly required as a _build_ dependency for FFI calls, although
it typically will be required as a _run time_ dependency. Also, it
probably wouldn't be too painful for external projects to use the
headers directly from the openxla/xla repo.
All that being said, I wanted to figure out how to do this, and it has
been requested a few times.
The currently supported values for compute type are `device_host` and `device`. `device_sparse` will be allowed in a follow-up CL. Using `device_host` means that the device's PJRT client will orchestrate the execution of the computation on the host.
`cpu` as a compute_type is reserved for pure CPU-only computations without a device's PJRT client orchestrating the computation.
PiperOrigin-RevId: 634909918
Metadata loader was using incorrect warp assignment, which resulted in incorrect addresses with num_warps>4. This was previously missed, as the autotuner rarely selected such configs.
PiperOrigin-RevId: 633513110
This change is in preparation for adding support for emitting source map information (https://tc39.es/source-map/) for jaxprs, so that the relationship between a jaxpr and its Python code can be visualized using existing source map tooling.
This change adds a new `source_map()` pretty printer document, which causes the pretty-printer to populate a source_map side output during pretty printing.
The change also teaches the core jaxpr pretty printer to populate source map information on each equation.
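For context, the source map format linked above stores each mapping segment as a run of Base64 VLQ values. The following is a minimal sketch of that encoding written from the specification; `vlq_encode` and `encode_segment` are illustrative names and are not the API added by this change.

```python
_B64 = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"


def vlq_encode(value: int) -> str:
    """Encode one signed integer as a Base64 VLQ string."""
    # The sign is stored in the lowest bit, the magnitude in the bits above it.
    v = ((-value) << 1) | 1 if value < 0 else value << 1
    out = []
    while True:
        digit = v & 0b11111          # take the next 5 bits
        v >>= 5
        if v:
            digit |= 0b100000        # continuation bit: more digits follow
        out.append(_B64[digit])
        if not v:
            break
    return "".join(out)


def encode_segment(fields) -> str:
    """Encode one mapping segment.

    Per the spec, `fields` holds the generated column, source index,
    original line, and original column, each relative to the previous
    segment (the generated column resets on every generated line).
    """
    return "".join(vlq_encode(f) for f in fields)


first_segment = encode_segment([0, 0, 0, 0])  # "AAAA"
```

Segments on the same generated line are joined with `,` and lines with `;` to form the `mappings` string of the source map.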
Usage:

    from jax.experimental.sparse import nm
    res = nm.nm_spmm(lhs, rhs, nm.nm_pack(mask))

where:

    lhs.shape = [M, K/2]
    rhs.shape = [K, N]

`mask` has the same shape as `lhs` with boolean type
If batch dimensions are present, the `dimension_numbers` argument has to be set to:

    ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))

The lowering only works on NVIDIA GPUs that provide hardware support for sparse dots.
PiperOrigin-RevId: 627640553
* Added a noop config_tags_overrides parameter to jax_test()
* Updated BUILD files necessary to run Pallas tests via Bazel
* Changed PallasTest to skip "large" test cases
PiperOrigin-RevId: 608534008