Created debug_info_test.py and moved there some of the
tests involving debug_info. In the future we will put here
more tests for debugging info, and their helper functions.
* keep track of all known config.State objects so we can find them by name.
* change `@jtu.with_config` to default to setting thread-local configurations.
* add a `@jtu.with_global_config` for those things that truly need to be set globally.
* add a `@jtu.thread_local_config_context` that overrides thread-local configuration options, just as `jtu.global_config_context` overrides global configuration options.
* change the pretty printer color option to be a State so it can be set locally.
* tag a number of tests as thread-hostile, in particular tests that check counters for numbers of compilations, rely on garbage collection having particular semantics, or look at log output.
PiperOrigin-RevId: 713411171
This subsumes (and ultimately will deprecate) overriding the number of CPU devices via XLA_FLAGS.
In addition, replace the test utility jtu.set_host_platform_device_count with jtu.request_cpu_devices(...), which sets or increases the flag's value. This both removes the need for an overly complicated context stack, and prepares for removing remaining uses of setUpModule as part of work parallelizing the test suite with threads.
PiperOrigin-RevId: 713272197
This has two problems:
* it's not thread-safe, which will become problematic if we run tests with thread-parallelism.
* it's not very maintainable.
Instead, add a new util.test_event(...) function that can be called at points of interest in the program. test_utils registers a callback that is invoked when an event is received. This avoids the need to make thread-unsafe global monkey patches.
This change only supports pinned_host -> pinned_host copies on the same device. HBM -> HBM copies don't work yet and donation also doesn't work in PJRT.
This CL also sets up the plumbing from JAX to PJRT so that in the future support for missing features can be added easily.
Fixes https://github.com/jax-ml/jax/issues/24521
PiperOrigin-RevId: 694274616
A noticeable amount of time during JAX tracing is spent getting and setting the value of config.State objects, in particular the thread-local values within that state. If we move that logic into C++, we can speed up that code.
There are two main ways we can get a speedup:
* Python thread-local state is based around a dictionary and isn't terribly fast.
* we can have the C++ jit dispatch path directly access the configuration items it needs to include in its cache key. We spend a considerable amount of time in effect eagerly computing cache keys via update_thread_local_jit_state, although most of that is pointless work. Instead, we can have `jit` simply pull the config items it needs on demand.
PiperOrigin-RevId: 693114411
The goal of this change is to catch PRs that introduce new warnings sooner.
To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable.
Add code to suppress some new warnings uncovered in CI.
PiperOrigin-RevId: 678352286
Not doing the resharding, leads to incorrect outputs on GPU and a crash on TPU which is not good.
Fixes: https://github.com/google/jax/issues/23100
PiperOrigin-RevId: 665000157
As described in https://github.com/google/jax/issues/21643, we're seeing
test failures in one `pmap` test under very specific circumstances. I
haven't been able to solve the issue, or even track down the original
source, since the failure has only been reproduced when running the full
test suite with `pytest`. Instead, this PR makes this test more lenient,
testing that grad-of-pmap produces the appropriate cache hits when used
a second time, rather than also checking the total number of `pmap` and
`jit` lowerings required.
The default thread pool size is too small on Mac OS.
An older version of this runtime based on StreamExecutor set a 2MiB stack size as well, but that change was most likely lost during the TFRT rewrite.
Fixes https://github.com/google/jax/issues/20428
PiperOrigin-RevId: 620853544
The StateContextManager restores its thread-local state to None, which means that the
initial thread-local state must also be None if the context manager is
to correctly restore the initial state.
This caused a test failure in a test case in pmap_test which checked for
exactly one cache entry across threads. One thread had used the
softmax_custom_jvp context manager, and had a different state (None)
instead of False.
This option changes the implementation of pmap so that the individual shards have the same rank as the entire array, i.e. in the terminology of pmap using a "chunked" axis instead of an "unstacked" axis.
i.e., previously a typical array used by pmap might have a shape of, say, [8, 100], if sharded across 8 accelerators on its first axis, and each individual shard would have a shape of, say, [100]. With this change, each individual shard has a shape of [1, 100] instead.
Why do this?
The main reason to do this is that XLA's sharding (HloSharding), which is exposed in JAX as GSPMDSharding/NamedSharding/PositionalSharding, cannot represent a change of rank. This means that the kind of sharding used by pmap cannot be represented to XLA as a sharding. If we change the definition of PmapSharding to preserve the array rank instead, then this means that PmapSharding can in the future be represented directly as a kind of sharding known to XLA.
The new definition of PmapSharding will allow a number of internal simplifications to JAX, for example in a subsequent change we can probably delete PmapSharding entirely. This in turn also would allow us to delete the APIs `jax.device_put_replicated` and `jax.device_put_sharded`, which predate the current sharding design.
This change also prepares for an upcoming change where we would like to redefine `pmap` in terms of `jit(shard_map(...))`, allowing us to delete most `pmap` code paths.
Once enabled, this change has the potential to break pmap users who:
a) look at the shards of an array, e.g., via `.addressable_shards`, or `jax.make_array_from_single_device_arrays`, since the shapes of the shards will change.
b) rely on zero-copy behavior in APIs like `jax.device_put_replicated`.
The change is disabled by default, so we do not expect any user visible impacts from this change.
PiperOrigin-RevId: 599787818
JAX has not used mhlo for some time, in favor of stablehlo. Deprecate support for this dialect in JAX's API and remove testing.
PiperOrigin-RevId: 598550225
This only affects python dispatch path. This has no impact on the speed of cpp dispatch (which is why benchmarks are **not** regressing).
If your code ends up taking the python dispatch, then something is going wrong anyways.
PiperOrigin-RevId: 596081987
This PR is a follow up to #18881.
The changes were generated by adding
from __future__ import annotations
to the files which did not already have them and running
pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
The multiplier for complex data types wasn't being applied correctly; the chunk_bytes calculation double-applied the multiplier.
Fixes https://github.com/google/jax/issues/18122
PiperOrigin-RevId: 573955671
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow up to #18008.
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.
Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().
PiperOrigin-RevId: 568923117