357 Commits

Author SHA1 Message Date
Peter Hawkins
70f91db853 Set PYTHONWARNINGS=error in bazel tests.
The goal of this change is to catch PRs that introduce new warnings sooner.

To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable.

Add code to suppress some new warnings uncovered in CI.

PiperOrigin-RevId: 678352286
2024-09-24 12:30:11 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Yash Katariya
6e1c23610d If input layouts are specified via in_shardings to jit and the array that the jitted function is called with is uncommitted, reshard the input array to the layout specified by the user.
Not doing the resharding, leads to incorrect outputs on GPU and a crash on TPU which is not good.

Fixes: https://github.com/google/jax/issues/23100
PiperOrigin-RevId: 665000157
2024-08-19 15:10:32 -07:00
Yash Katariya
abc9ba00e9 Rename count_jit_and_pmap_compiles to count_jit_and_pmap_lowerings
PiperOrigin-RevId: 661496993
2024-08-09 20:03:43 -07:00
Matthew Johnson
88d1cd731d remove pdot and xeinsum (since xmap is gone) 2024-07-25 21:19:17 +00:00
Yash Katariya
0d5dae09ff Delete xmap and the jax.experimental.maps module. It's been 5 months since its deprecation (more than the standard 3 months deprecation period).
PiperOrigin-RevId: 655614395
2024-07-24 10:24:09 -07:00
jax authors
5019167106 Further reduce the run time of pmap_test
PiperOrigin-RevId: 643548103
2024-06-14 23:26:14 -07:00
jax authors
06ec7d1ad5 Reduce the matrix size in testPmapMapVmapCombinations to reduce the test run time.
PiperOrigin-RevId: 643166085
2024-06-13 17:02:03 -07:00
Jake VanderPlas
a861c55a28 test cleanup: use ExitStack to reduce test boilerplate 2024-06-06 14:18:27 -07:00
Dan Foreman-Mackey
a46d5c2a30 Simplify flaky test of grad-of-pmap cache hits
As described in https://github.com/google/jax/issues/21643, we're seeing
test failures in one `pmap` test under very specific circumstances. I
haven't been able to solve the issue, or even track down the original
source, since the failure has only been reproduced when running the full
test suite with `pytest`. Instead, this PR makes this test more lenient,
testing that grad-of-pmap produces the appropriate cache hits when used
a second time, rather than also checking the total number of `pmap` and
`jit` lowerings required.
2024-06-04 16:49:11 -04:00
Jake VanderPlas
f5a59ccd4d Fix warning filter in test_pmap_of_prng_key 2024-05-15 13:20:05 -07:00
Jake VanderPlas
5150cfeeb0 Fix PRNGKey handling under jit-of-pmap 2024-05-13 19:04:22 -07:00
Yash Katariya
395d3cb79e Bump minimum jaxlib version to 0.4.27
xla_extension_version is 261 and mlir_api_version is 56

PiperOrigin-RevId: 631579739
2024-05-07 16:07:59 -07:00
Peter Hawkins
011ced4431 Guard test that requires two devices with device_count() check.
PiperOrigin-RevId: 620921563
2024-04-01 12:32:54 -07:00
Peter Hawkins
77ff8a2339 [PJRT:CPU] Fix thread-pool stack sizes to 2MiB.
The default thread pool size is too small on Mac OS.

An older version of this runtime based on StreamExecutor set a 2MiB stack size as well, but that change was most likely lost during the TFRT rewrite.

Fixes https://github.com/google/jax/issues/20428

PiperOrigin-RevId: 620853544
2024-04-01 08:20:36 -07:00
Peter Hawkins
8815b236b6 Disable collective broadcast pmap test on older jaxlibs.
Collective broadcast was only recently added to xla.

PiperOrigin-RevId: 620287470
2024-03-29 10:59:10 -07:00
Chase Roberts
01412f7645 pbroadcast 2024-03-18 15:12:33 -07:00
Yash Katariya
67b0eb3af4 Improve pytree mismatch error in AOT
PiperOrigin-RevId: 612560820
2024-03-04 13:15:32 -08:00
Jake VanderPlas
d08e9a03d8 [key reuse] add eager checks 2024-02-29 15:30:19 -08:00
Jake VanderPlas
cddee4654c tests: access tree utilities via jax.tree.* 2024-02-26 14:17:18 -08:00
Peter Hawkins
d713f3a632 Fix test flake which occurred due to a spurious cache misses.
The StateContextManager restores its thread-local state to None, which means that the
initial thread-local state must also be None if the context manager is
to correctly restore the initial state.

This caused a test failure in a test case in pmap_test which checked for
exactly one cache entry across threads. One thread had used the
softmax_custom_jvp context manager, and had a different state (None)
instead of False.
2024-02-20 22:54:31 +00:00
Peter Hawkins
fc6df3218c Add a new experimental option jax_pmap_no_rank_reduction.
This option changes the implementation of pmap so that the individual shards have the same rank as the entire array, i.e. in the terminology of pmap using a "chunked" axis instead of an "unstacked" axis.

i.e., previously a typical array used by pmap might have a shape of, say, [8, 100], if sharded across 8 accelerators on its first axis, and each individual shard would have a shape of, say, [100]. With this change, each individual shard has a shape of [1, 100] instead.

Why do this?

The main reason to do this is that XLA's sharding (HloSharding), which is exposed in JAX as GSPMDSharding/NamedSharding/PositionalSharding, cannot represent a change of rank. This means that the kind of sharding used by pmap cannot be represented to XLA as a sharding. If we change the definition of PmapSharding to preserve the array rank instead, then this means that PmapSharding can in the future be represented directly as a kind of sharding known to XLA.

The new definition of PmapSharding will allow a number of internal simplifications to JAX, for example in a subsequent change we can probably delete PmapSharding entirely. This in turn also would allow us to delete the APIs `jax.device_put_replicated` and `jax.device_put_sharded`, which predate the current sharding design.

This change also prepares for an upcoming change where we would like to redefine `pmap` in terms of `jit(shard_map(...))`, allowing us to delete most `pmap` code paths.

Once enabled, this change has the potential to break pmap users who:
a) look at the shards of an array, e.g., via `.addressable_shards`, or `jax.make_array_from_single_device_arrays`, since the shapes of the shards will change.
b) rely on zero-copy behavior in APIs like `jax.device_put_replicated`.

The change is disabled by default, so we do not expect any user visible impacts from this change.

PiperOrigin-RevId: 599787818
2024-01-19 03:53:37 -08:00
Peter Hawkins
e558feaa5e Deprecate support for the mhlo dialect.
JAX has not used mhlo for some time, in favor of stablehlo. Deprecate support for this dialect in JAX's API and remove testing.

PiperOrigin-RevId: 598550225
2024-01-15 02:13:40 -08:00
Yash Katariya
b8098b1782 Remove indices and devices from shard_arg_handlers and shard_args.
This only affects python dispatch path. This has no impact on the speed of cpp dispatch (which is why benchmarks are **not** regressing).

If your code ends up taking the python dispatch, then something is going wrong anyways.

PiperOrigin-RevId: 596081987
2024-01-05 14:17:14 -08:00
Yash Katariya
c0d4653fc9 Delete sharding spec to HloSharding conversion since it's not used anymore.
PiperOrigin-RevId: 595192496
2024-01-02 13:13:23 -08:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Jake VanderPlas
97beb01c43 Deprecate the device() method of JAX arrays 2023-11-30 11:43:02 -08:00
Jake VanderPlas
d2b4800723 tests: improve warnings-related tests 2023-11-30 10:35:24 -08:00
Peter Hawkins
30a0136813 Increase minimum jaxlib version to 0.4.19.
0.4.19 has xla_extension version 207 and mlir_api_version 54.

PiperOrigin-RevId: 583412447
2023-11-17 09:38:31 -08:00
Peter Hawkins
f4eb3f6d86 Add a regression test for a pmap issue that is fixed at head.
Fixes https://github.com/google/jax/issues/5757

PiperOrigin-RevId: 580243825
2023-11-07 11:21:21 -08:00
Peter Hawkins
89b5449882 [XLA:GPU] Fix bug in all-to-all for complex data types.
The multiplier for complex data types wasn't being applied correctly; the chunk_bytes calculation double-applied the multiplier.

Fixes https://github.com/google/jax/issues/18122

PiperOrigin-RevId: 573955671
2023-10-16 16:02:22 -07:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Peter Hawkins
1885c4933c Add a new internal test utility test_device_matches() and use it instead of equality tests on device_under_test().
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.

Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().

PiperOrigin-RevId: 568923117
2023-09-27 12:10:43 -07:00
Peter Hawkins
5aaa15df84 Remove the skip_on_xla_cpu_mlir decorator.
We no longer test this variant in CI, so we don't need code to skip it.

PiperOrigin-RevId: 568219651
2023-09-25 08:04:56 -07:00
Berkin Ilbeyi
c9b5996f59 [XLA] Initialize tuple shapes of async-done in dataflow analysis.
PiperOrigin-RevId: 567724401
2023-09-22 14:59:31 -07:00
Yash Katariya
03877a9218 If a pmap out is replicated i.e. with out_axes=None make jnp.copy's impl go via apply_primitive which will put it on a single device.
If we don't do that, then it hits an error mentioned in https://github.com/google/jax/issues/17690.

Fixes https://github.com/google/jax/issues/17690

PiperOrigin-RevId: 567628026
2023-09-22 08:24:57 -07:00
Jake VanderPlas
2f878a7168 Tests: set jax_legacy_prng_key='error' 2023-08-28 10:56:09 -07:00
Peter Hawkins
2c32660a8f Replace references to DeviceArray with Array.
A number of stale references are lurking in our documentation.
2023-08-18 17:46:00 -04:00
Peter Hawkins
ca17b6c08f Move functions out of xla.py closer to their users.
Refactoring only, no changes intended. The goal is to shrink xla.py down to only its HLO-compatibility role, and remove things that aren't related to HLO compatibility.

Remove an unused top_k translation rule as well.

PiperOrigin-RevId: 554946059
2023-08-08 14:40:42 -07:00
Peter Hawkins
26727ea713 Delete jax.interpreters.pxla.replicate().
pxla.replicate() can be replaced by jax.device_put_replicated().

No deprecation period because jax.interpreters APIs are not stable.

PiperOrigin-RevId: 553502827
2023-08-03 09:37:00 -07:00
Peter Hawkins
7df3477926 [JAX] Use MLIR argument locations instead of a bespoke jax.arg_info attribute.
514dddbeba allowed for specifying argument Locations in the MLIR Python bindings. We should use them, in the form of a Name location, rather than making up our own attribute.

Example of new output:

```
In [1]: import jax
In [2]: ir = jax.jit(lambda x, y: x + y).lower(7, 3).compiler_ir()
In [3]: ir.operation.print(enable_debug_info=True)
#loc1 = loc("x")
#loc2 = loc("y")
module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.sharding = "{replicated}"} loc("x"), %arg1: tensor<i32> {mhlo.sharding = "{replicated}"} loc("y")) -> (tensor<i32> {jax.result_info = ""}) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<i32> loc(#loc4)
    return %0 : tensor<i32> loc(#loc)
  } loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc3 = loc("<ipython-input-2-ef5a568a0c1c>":1:0)
#loc4 = loc("jit(<lambda>)/jit(main)/add"(#loc3))
```

Note debug information must be enabled.

PiperOrigin-RevId: 549325621
2023-07-19 08:39:16 -07:00
jax authors
e894e4817a Remove deprecated compiler_ir from Compiled
PiperOrigin-RevId: 547211085
2023-07-11 09:24:48 -07:00
Roy Frostig
1ad0a11897 AOT: better error messages on call signature mismatch
Also update error example in AOT docs.
2023-07-10 22:10:50 -07:00
Peter Hawkins
803c729b57 Fix some test failures under H100.
It seems that under H100 matmul precisions are a little lower by default than they historically were on A100. Opt out of tensorcore matmuls for tests that fail due to precision issues if they are enabled.

Happily, this also allows us to remove a number of TPU special cases for the same reason.

PiperOrigin-RevId: 539101155
2023-06-09 09:23:36 -07:00
Yash Katariya
ae9d1498e5 Bump minimum jaxlib version to 0.4.11. xla_extension_version is 158 and mlir_api_version is 49. It will subsume https://github.com/google/jax/pull/16161#issuecomment-1564977332
PiperOrigin-RevId: 537047525
2023-06-01 09:42:55 -07:00
Jake VanderPlas
9cfe77d5e1 Remove use of deprecated make_sharded_device_array 2023-05-03 10:11:29 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Yash Katariya
3722d7066a Add jax_pmap_shmap_merge flag to begin the process of merging pmap and shard_map
After the changes in shard_map, there are 75 failures left to be resolved (not counting the EagerPmap tests).

TODO:
* Move shard_map to _src so that the circular import can be removed from api.py
PiperOrigin-RevId: 525930416
2023-04-20 21:22:48 -07:00
Yash Katariya
53e6382f4a Add arg_names to aval mismatch error raised during AOT compilation to raise better error messages
PiperOrigin-RevId: 525561905
2023-04-19 15:08:53 -07:00
Jake VanderPlas
5521423d92 Change np.prod->math.prod
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
2023-04-13 11:48:11 -07:00