The goal of this change is to catch PRs that introduce new warnings sooner.
To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable.
Add code to suppress some new warnings uncovered in CI.
PiperOrigin-RevId: 678352286
Skipping the resharding leads to incorrect outputs on GPU and a crash on TPU.
Fixes: https://github.com/google/jax/issues/23100
PiperOrigin-RevId: 665000157
As described in https://github.com/google/jax/issues/21643, we're seeing
test failures in one `pmap` test under very specific circumstances. I
haven't been able to solve the issue, or even track down the original
source, since the failure has only been reproduced when running the full
test suite with `pytest`. Instead, this PR makes this test more lenient,
testing that grad-of-pmap produces the appropriate cache hits when used
a second time, rather than also checking the total number of `pmap` and
`jit` lowerings required.
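The pattern under test looks roughly like this (a sketch; the cache-counting helpers used in the actual test are internal):
```
import jax
import jax.numpy as jnp

f = jax.pmap(lambda x: x ** 2)
g = jax.grad(lambda x: f(x).sum())

x = jnp.ones((jax.local_device_count(), 4))
g(x)  # the first call populates the lowering/compilation caches
g(x)  # the test now asserts cache hits here, rather than counting
      # the total number of pmap and jit lowerings
```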
The default thread pool size is too small on macOS.
An older version of this runtime based on StreamExecutor set a 2MiB stack size as well, but that change was most likely lost during the TFRT rewrite.
Fixes https://github.com/google/jax/issues/20428
PiperOrigin-RevId: 620853544
The StateContextManager restores its thread-local state to None, which means that the
initial thread-local state must also be None if the context manager is
to correctly restore the initial state.
This caused a test failure in pmap_test, in a test case that checked for
exactly one cache entry across threads. One thread had used the
softmax_custom_jvp context manager and was left with a different state
(None instead of False).
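A minimal sketch of the invariant, with illustrative names rather than JAX's actual classes:
```
import threading

class _ThreadLocalState(threading.local):
    value = None  # the initial per-thread state

_state = _ThreadLocalState()

class StateContextManager:
    def __init__(self, new_value):
        self.new_value = new_value

    def __enter__(self):
        _state.value = self.new_value

    def __exit__(self, *exc_info):
        # Restores to None unconditionally. If a thread's initial state were
        # False instead of None, any thread that had used this manager would
        # end up with a state (None) different from fresh threads (False).
        _state.value = None
```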
This option changes the implementation of pmap so that the individual shards have the same rank as the entire array, i.e., in the terminology of pmap, using a "chunked" axis instead of an "unstacked" axis.
For example, previously a typical array used by pmap might have a shape of [8, 100] if sharded across 8 accelerators on its first axis, and each individual shard would have a shape of [100]. With this change, each individual shard has a shape of [1, 100] instead.
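A sketch of the difference (assumes 8 local devices are available):
```
import jax
import jax.numpy as jnp

x = jax.pmap(lambda v: v)(jnp.zeros((8, 100)))
print(x.shape)  # (8, 100)
shard = x.addressable_shards[0].data
# Old "unstacked" definition: shard.shape == (100,)
# New "chunked" definition:   shard.shape == (1, 100)
```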
Why do this?
The main reason to do this is that XLA's sharding (HloSharding), which is exposed in JAX as GSPMDSharding/NamedSharding/PositionalSharding, cannot represent a change of rank. This means that the kind of sharding used by pmap cannot be represented to XLA as a sharding. If we change the definition of PmapSharding to preserve the array rank instead, then this means that PmapSharding can in the future be represented directly as a kind of sharding known to XLA.
The new definition of PmapSharding will allow a number of internal simplifications to JAX, for example in a subsequent change we can probably delete PmapSharding entirely. This in turn also would allow us to delete the APIs `jax.device_put_replicated` and `jax.device_put_sharded`, which predate the current sharding design.
This change also prepares for an upcoming change where we would like to redefine `pmap` in terms of `jit(shard_map(...))`, allowing us to delete most `pmap` code paths.
Once enabled, this change has the potential to break pmap users who:
a) look at the shards of an array, e.g., via `.addressable_shards` or `jax.make_array_from_single_device_arrays`, since the shapes of the shards will change.
b) rely on zero-copy behavior in APIs like `jax.device_put_replicated`.
The change is disabled by default, so we do not expect any user-visible impact for now.
PiperOrigin-RevId: 599787818
JAX has not used the mhlo dialect for some time, having moved to stablehlo. Deprecate support for mhlo in JAX's APIs and remove its testing.
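For reference, the supported path via the lowering API (a sketch):
```
import jax

lowered = jax.jit(lambda x: x + 1).lower(1.0)
# StableHLO is the supported dialect; dialect="mhlo" is what is deprecated.
print(lowered.compiler_ir(dialect="stablehlo"))
```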
PiperOrigin-RevId: 598550225
This only affects the Python dispatch path; it has no impact on the speed of the C++ dispatch path (which is why benchmarks are **not** regressing).
If your code ends up taking the Python dispatch path, something is already going wrong anyway.
PiperOrigin-RevId: 596081987
This PR is a follow up to #18881.
The changes were generated by adding
```
from __future__ import annotations
```
to the files which did not already have it, and then running
```
pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
```
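A representative rewrite produced by this command (a sketch):
```
from __future__ import annotations

# Before:
#   from typing import List
#   def f(xs: List[int]) -> List[int]: ...

# After pyupgrade --py39-plus:
def f(xs: list[int]) -> list[int]:
    return xs
```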
The multiplier for complex data types wasn't being applied correctly: the `chunk_bytes` calculation applied the multiplier twice.
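A sketch of the bug pattern, with illustrative names rather than the actual serialization code:
```
import numpy as np

def chunk_bytes_buggy(num_elements, dtype):
    # np.dtype(...).itemsize already includes the complex factor
    # (e.g. complex64 is 8 bytes), so this multiplier double-counts it.
    multiplier = 2 if np.issubdtype(dtype, np.complexfloating) else 1
    return multiplier * num_elements * np.dtype(dtype).itemsize

def chunk_bytes_fixed(num_elements, dtype):
    return num_elements * np.dtype(dtype).itemsize
```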
Fixes https://github.com/google/jax/issues/18122
PiperOrigin-RevId: 573955671
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow up to #18008.
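For example (a sketch; the typed state objects live in `jax._src.config`):
```
# Before: a dynamic attribute lookup on jax.config, opaque to type checkers.
import jax
enabled = jax.config.jax_enable_x64

# After: a statically-typed state object with a .value attribute.
from jax._src import config
enabled = config.enable_x64.value
```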
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.
Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().
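A representative before/after (a sketch, with `jax._src.test_util` imported as `jtu`):
```
from jax._src import test_util as jtu

# Before: an equality test against the device string.
#   if jtu.device_under_test() == "gpu":
#       ...

# After: matching against a list of device types, which leaves room for
# "gpu" to later match both "cuda" and "rocm".
if jtu.test_device_matches(["gpu"]):
    print("running on a GPU backend")
```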
PiperOrigin-RevId: 568923117
Refactoring only, no changes intended. The goal is to shrink xla.py down to only its HLO-compatibility role, and remove things that aren't related to HLO compatibility.
Remove an unused top_k translation rule as well.
PiperOrigin-RevId: 554946059
pxla.replicate() can be replaced by jax.device_put_replicated().
No deprecation period because jax.interpreters APIs are not stable.
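The replacement in terms of the public API (a sketch):
```
import jax
import jax.numpy as jnp

x = jnp.arange(4)
# Previously: jax.interpreters.pxla.replicate(...)
replicated = jax.device_put_replicated(x, jax.local_devices())
print(replicated.shape)  # (num_devices, 4): one stacked copy per device
```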
PiperOrigin-RevId: 553502827
514dddbeba made it possible to specify argument Locations in the MLIR Python bindings. We should use them, in the form of a Name location, rather than making up our own attribute.
Example of new output:
```
In [1]: import jax
In [2]: ir = jax.jit(lambda x, y: x + y).lower(7, 3).compiler_ir()
In [3]: ir.operation.print(enable_debug_info=True)
#loc1 = loc("x")
#loc2 = loc("y")
module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<i32> {mhlo.sharding = "{replicated}"} loc("x"), %arg1: tensor<i32> {mhlo.sharding = "{replicated}"} loc("y")) -> (tensor<i32> {jax.result_info = ""}) {
%0 = stablehlo.add %arg0, %arg1 : tensor<i32> loc(#loc4)
return %0 : tensor<i32> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc3 = loc("<ipython-input-2-ef5a568a0c1c>":1:0)
#loc4 = loc("jit(<lambda>)/jit(main)/add"(#loc3))
```
Note that debug information must be enabled (here via `enable_debug_info=True`) for the locations to be shown.
PiperOrigin-RevId: 549325621
It seems that on H100, matmul precision is a little lower by default than it historically was on A100. Opt out of tensorcore matmuls for tests that fail due to precision issues when tensorcores are enabled.
Happily, this also allows us to remove a number of TPU special cases for the same reason.
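One way a test can opt out of tensorcore matmuls (a sketch; individual tests may use a different mechanism):
```
import jax
import jax.numpy as jnp

# Force float32 matmuls instead of the lower default tensorcore precision.
with jax.default_matmul_precision("float32"):
    result = jnp.dot(jnp.ones((4, 4)), jnp.ones((4, 4)))
```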
PiperOrigin-RevId: 539101155
After the changes in shard_map, there are 75 failures left to be resolved (not counting the EagerPmap tests).
TODO:
* Move shard_map to _src so that the circular import can be removed from api.py
PiperOrigin-RevId: 525930416
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner case: np.prod([]) returns a float rather than an int.
math.prod has been available since Python 3.8 and is a better fit here.
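The corner case:
```
import math
import numpy as np

print(np.prod([]))    # 1.0 -- a float, which breaks integer shape arithmetic
print(math.prod([]))  # 1   -- an int, the correct identity for an empty shape
```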