316 Commits

Author SHA1 Message Date
Ilia Sergachev
b320dc2e5e Fix and reenable cudnn_fusion_test.
Disable XLA autotuning fallback to cuBLAS so that the tested fusion
always executes through cuDNN.
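A minimal sketch of one way to force this behavior outside the test harness, assuming the `xla_gpu_cublas_fallback` XLA flag (the flag name is an assumption, not taken from this change) is the relevant knob:

```py
import os

# Assumed flag; append so any pre-existing XLA_FLAGS are preserved.
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_cublas_fallback=false"

import jax  # import after setting the flag so XLA picks it up
```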
2024-09-30 14:03:55 +00:00
Peter Hawkins
5969e79908 Fix tests that ask for an accelerator but don't use it.
* Delete custom_object_test, since it is disabled and has been ever since jax.Array was introduced in JAX 0.4.0.
* custom_linear_solve_test was over-sharded, leading to some shards not having any test cases. Even unsharded it completes in under 65s on every platform we have.
* config_test and pallas splash attention mask test only tested helpers and didn't need a TPU.

PiperOrigin-RevId: 679711664
2024-09-27 13:36:23 -07:00
Peter Hawkins
26632fd344 Replace disable_backends with enable_backends on jax_multiplatform_test.
Most users of disable_backends were actually using it to enable only a single backend. So things are simpler if we negate the sense of the option to say that. Change disable_configs to enable_configs, with a default `None` value meaning "everything is enabled".

We change the relationship between enable_backends, disable_configs, and enable_configs to be the following (a hypothetical BUILD sketch follows the list):
* `enable_backends` selects a set of initial test configurations to enable, based off backend only.
* `disable_configs` then prunes that set of test configurations, removing elements from the set.
* `enable_configs` then adds additional configurations to the set.
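A hypothetical BUILD entry illustrating these semantics (the target and configuration names below are made up for illustration, not taken from this change):

```py
jax_multiplatform_test(
    name = "example_test",              # hypothetical target
    srcs = ["example_test.py"],
    enable_backends = ["tpu"],          # start from all TPU configurations
    disable_configs = ["tpu_v3_2x2"],   # then prune this configuration from the set
    enable_configs = ["gpu_h100"],      # then add this extra configuration
)
```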

Fix code in jax/experimental/mosaic/gpu/examples not to depend on a Google-internal GPU support target.

PiperOrigin-RevId: 679563155
2024-09-27 06:15:31 -07:00
Bart Chrzaszcz
a3284bd8a3 #sdy Add CPU targets in JAX.
PiperOrigin-RevId: 679174535
2024-09-26 09:13:34 -07:00
Bart Chrzaszcz
e62a50cd34 #sdy add JAX Shardy support for shard_map.
For example the following JAX program:
```py
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

devices = np.array(jax.devices()[:8])
mesh = Mesh(devices, axis_names=('x',))
a = jax.device_put(
    jnp.arange(8 * 8).reshape((8, 8)),
    jax.sharding.NamedSharding(mesh, P('x', None)))

@jax.jit
@partial(
    shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None)
)
def fwd(a):
  axis_size = lax.psum(1, 'x')
  perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
  return lax.ppermute(a, 'x', perm=perm)

print(jax.jit(fwd).lower(a).as_text())
```

prints:

```mlir
module @jit_fwd attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <["x"=8]>
  func.func public @main(%arg0: tensor<8x8xi32> {mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<8x8xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = call @fwd(%arg0) : (tensor<8x8xi32>) -> tensor<8x8xi32>
    return %0 : tensor<8x8xi32>
  }
  func.func private @fwd(%arg0: tensor<8x8xi32> {mhlo.layout_mode = "default"}) -> (tensor<8x8xi32> {mhlo.layout_mode = "default"}) {
    %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"x"}, {}]>] out_shardings=[<@mesh, [{"x"}, {}]>] manual_axes={"x"} (%arg1: tensor<1x8xi32>) {
      %1 = "stablehlo.collective_permute"(%arg1) <{channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, source_target_pairs = dense<[[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 0]]> : tensor<8x2xi64>}> : (tensor<1x8xi32>) -> tensor<1x8xi32>
      sdy.return %1 : tensor<1x8xi32>
    } : (tensor<8x8xi32>) -> tensor<8x8xi32>
    return %0 : tensor<8x8xi32>
  }
}
```

PiperOrigin-RevId: 679165100
2024-09-26 08:45:40 -07:00
Peter Hawkins
1949413739 Increase sharding of checkify_test on TPU to fix CI flakes.
PiperOrigin-RevId: 678720498
2024-09-25 08:54:29 -07:00
Peter Hawkins
a43c7f2ace Enable more H100 tests in CI.
Rename "gpu" config CI tag to "gpu_v100".

PiperOrigin-RevId: 678695003
2024-09-25 07:37:48 -07:00
Peter Hawkins
85a466d730 Lower the shard count for sparse_bcoo_bcsr_test on TPU as well.
There are flaky timeouts in CI, and we've already lowered the shard count on multiple other platforms.

PiperOrigin-RevId: 678367575
2024-09-24 13:10:32 -07:00
Peter Hawkins
70f91db853 Set PYTHONWARNINGS=error in bazel tests.
The goal of this change is to catch PRs that introduce new warnings sooner.

To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable.
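A minimal Starlark sketch (an illustration, not the actual macro) of how such a py_test wrapper could inject the environment variable:

```py
# Hypothetical wrapper; the real jax_py_test macro may differ in detail.
def jax_py_test(name, env = {}, **kwargs):
    env = dict(env)  # don't mutate the caller's dict
    env.setdefault("PYTHONWARNINGS", "error")  # turn new warnings into test failures
    native.py_test(name = name, env = env, **kwargs)
```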

Add code to suppress some new warnings uncovered in CI.

PiperOrigin-RevId: 678352286
2024-09-24 12:30:11 -07:00
jax authors
dc1ace5992 Re-enable tsan tests after fix.
PiperOrigin-RevId: 677895934
2024-09-23 12:26:30 -07:00
jax authors
28b5dee032 Disable flaky tsan tests temporarily.
PiperOrigin-RevId: 674338720
2024-09-13 10:03:24 -07:00
Peter Hawkins
95f38d95d7 Update TPU test configuration tags.
PiperOrigin-RevId: 672562923
2024-09-09 09:02:51 -07:00
Peter Hawkins
fe63b991dd Disable cudnn_fusion_test from CI.
This test isn't passing in our internal CI.

PiperOrigin-RevId: 672507574
2024-09-09 05:16:13 -07:00
jax authors
02b7a76768 Add frontend attributes to JAX. This allows JAX users to annotate JAX code with frontend_attributes that can be traced down to the HLO level, for use in numerical debugging.
PiperOrigin-RevId: 671930431
2024-09-06 16:44:56 -07:00
Ilia Sergachev
85d792a92d Add a cudnn_fusion decorator that lowers computations to XLA cuDNN fusions. 2024-09-05 01:25:54 +02:00
jax authors
2f3990d13c Remove CPU test variant.
PiperOrigin-RevId: 669359594
2024-08-30 09:58:32 -07:00
Peter Hawkins
6d1f51e63d Clean up BUILD files.
PiperOrigin-RevId: 667604964
2024-08-26 09:11:17 -07:00
Bart Chrzaszcz
71b7e78916 Add jax_test configs for Shardy, enable it for pjit_test.py, and fix any failing tests.
Tests fixed include:

- `test_globally_sharded_key_array_8x4_multi_device`
  - Issue was in `replicate_trailing_dims` where an `xc.OpSharding` was always created. Fixed by creating an equivalent SDY sharding.
- `test_aot_out_info`
  - Issue was there was no mesh since there weren't any NamedShardings. Fixed by not asserting a mesh tuple exists in `lower_jaxpr_to_module` when adding the sdy MeshOp (there won't be any propagation)
- `test_concurrent_pjit`
  - In Shardy, if there was a tensor dimension of size 0, we'd emit a verification error if the dimension was sharded on an axis. But if the axis is of size 1, then JAX says this is okay, so have Shardy assume the same.
- `test_globally_sharded_key_array_result_8x4_single_device`
  - This test adds a WSC when no `mesh_shape_tuple` exists (`"sdy.sharding_constraint"(%8) <{sharding = #sdy.sharding<@mesh, [{?}, {?}, {}]>}>`), so we should create a mesh named `mesh` with a single device ID in case it doesn't exist.
- `testLowerCostAnalysis`
  - This calls into `mlir_module_to_xla_computation` which calls its own MLIR parsing function in `//third_party/tensorflow/compiler/xla/python/mlir.cc`. Needed to register the SDY dialect in it.
- `testShardingConstraintWithArray`
  - This calls `.compiler_ir(dialect="hlo")` which calls `PyMlirModuleToXlaComputation` which converts the MLIR to HLO, but the Sdy dialect is still inside. Export it before converting it to HLO.

PiperOrigin-RevId: 666777167
2024-08-23 06:51:13 -07:00
George Necula
dbd6aeebb7 Disable some asan tests that time out
PiperOrigin-RevId: 662774152
2024-08-13 22:03:29 -07:00
Sergei Lebedev
d8eafc8ee3 Disabled nn_test under asan on TPU as well, since it also times out
PiperOrigin-RevId: 660950262
2024-08-08 13:02:31 -07:00
Sergei Lebedev
3a1567f57a Do not run nn_test under asan -- it times out
PiperOrigin-RevId: 660377176
2024-08-07 07:14:27 -07:00
Vladimir Belitskiy
7f96b263d4 Un-skip //third_party/py/jax/tests:pytorch_interoperability_test_cpu on ASAN.
It should have been fixed via
https://github.com/pytorch/pytorch/issues/117058#issuecomment-1973020150

PiperOrigin-RevId: 656464550
2024-07-26 11:10:41 -07:00
Vladimir Belitskiy
282ebf4882 Skip //third_party/py/jax/tests:pytorch_interoperability_test_cpu on MSAN.
MSAN has issues with using `-c opt` in some cases, which prevents this
test from running properly.

PiperOrigin-RevId: 656454585
2024-07-26 10:44:19 -07:00
Vladimir Belitskiy
ba50e77407 Increase shard count for //third_party/py/jax/tests:lax_numpy_ufuncs_test_cpu.
PiperOrigin-RevId: 655946922
2024-07-25 07:26:06 -07:00
Yash Katariya
0d5dae09ff Delete xmap and the jax.experimental.maps module. It's been 5 months since its deprecation (more than the standard 3-month deprecation period).
PiperOrigin-RevId: 655614395
2024-07-24 10:24:09 -07:00
Eugene Zhulenev
e3fc63cafb [xla:cpu] Support for up to 16 sorted inputs
+ enable more jax/lax tests for XLA CPU thunks

PiperOrigin-RevId: 655249641
2024-07-23 11:54:31 -07:00
Vladimir Belitskiy
a1f2a50cfa Increase shard count under TPU for //third_party/py/jax/tests:lax_numpy_test.
PiperOrigin-RevId: 654847718
2024-07-22 12:08:04 -07:00
Eugene Zhulenev
e23de7d790 [xla:cpu] Add support for XLA:CPU+Thunks compilation cache
+ fix custom call thunk for tuple result of size 1

PiperOrigin-RevId: 653381010
2024-07-17 15:29:13 -07:00
Gleb Pobudzey
46103f6ff3 Updated the repr of GPU devices to be more consistent with TPUs/CPUs.
For example, `cuda(id=0)` will now be `CudaDevice(id=0)`.
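A quick way to see the new repr (illustrative; assumes an NVIDIA GPU is available):

```py
import jax

# Prints something like CudaDevice(id=0) instead of the old cuda(id=0).
print(jax.devices()[0])
```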

PiperOrigin-RevId: 651393690
2024-07-11 06:54:20 -07:00
Kaixi Hou
df6080f346 PR #21371: [NVIDIA] Add new SDPA API to jax.nn
Imported from GitHub PR https://github.com/google/jax/pull/21371

Attention plays a crucial role in modern transformer-based models. While many variants exist, they generally follow the same workflow; examples include the typical multi-head attention (MHA), grouped-query attention (GQA), and multi-query attention (MQA). Additionally, new implementations like the Flash Attention algorithm aim to improve the utilization of accelerator devices. For instance, NVIDIA cuDNN supports Flash Attention and, through its API, can deliver a 1.3x end-to-end speedup for training GPT-based large language models.

This PR proposes introducing a new API in the `jax.nn` module to handle attention. It first tries the cuDNN flash attention execution path when the configuration is compatible; otherwise, it falls back to a pure-JAX implementation.
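A hedged usage sketch of the proposed entry point; the function name and keywords below are assumptions based on this description rather than a verbatim copy of the final API:

```py
import jax
import jax.numpy as jnp

B, T, N, H = 2, 128, 8, 64  # batch, sequence length, heads, head dim (illustrative)
kq, kk, kv = jax.random.split(jax.random.key(0), 3)
q = jax.random.normal(kq, (B, T, N, H), dtype=jnp.bfloat16)
k = jax.random.normal(kk, (B, T, N, H), dtype=jnp.bfloat16)
v = jax.random.normal(kv, (B, T, N, H), dtype=jnp.bfloat16)

# Uses the cuDNN flash attention path when the configuration is compatible,
# otherwise falls back to a pure-JAX implementation.
out = jax.nn.dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # (2, 128, 8, 64)
```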

cc. @nluehr @Cjkkkk @cliffwoolley

Copybara import of the project:

--
39a11d91632aab1af5aeec1e92990a7aaeea0cca by kaixih <kaixih@nvidia.com>:

Add new SDPA API to jax.nn

Merging this change closes #21371

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21371 from kaixih:jax_sdpa_dev 39a11d91632aab1af5aeec1e92990a7aaeea0cca
PiperOrigin-RevId: 650225872
2024-07-08 06:16:04 -07:00
Ayaka
6c05aa2f32 Clean up 2024-07-04 17:16:32 +04:00
Adam Paszke
727d120401 Bump up the shard_count for GPU FFT tests
They seem to be timing out with ASAN and no sharding.

PiperOrigin-RevId: 648301571
2024-07-01 02:58:23 -07:00
George Necula
cbe524298c Ported threefry2x32 for GPU to the typed XLA FFI
This allows lowering of threefry2x32 for GPU even on a machine without GPUs.

For the next 3 weeks, we only use the new custom call implementation if
we are not in "export" mode, and if we use a new jaxlib.

PiperOrigin-RevId: 647657084
2024-06-28 06:24:44 -07:00
George Necula
24b42eed5e [export] Clean up BUILD targets for jax.experimental.export
jax.experimental.export is deprecated and will be removed in a future version of JAX.

See migration guide at: https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export

PiperOrigin-RevId: 647562073
2024-06-27 23:08:48 -07:00
jax authors
8b22b5fbcc Merge pull request #22087 from jakevdp:validate-config
PiperOrigin-RevId: 646963494
2024-06-26 08:51:34 -07:00
Jake VanderPlas
fa73077146 jax.config: validate on set() 2024-06-25 09:02:32 -07:00
Justin Fu
8ba8f3bf65 [Pallas] Implement block-invariant sampling.
PiperOrigin-RevId: 646161271
2024-06-24 11:20:39 -07:00
Eugene Zhulenev
3fd9326881 [jax] Enable api_test with XLA:CPU thunks
PiperOrigin-RevId: 644268375
2024-06-17 23:58:02 -07:00
Peter Hawkins
339027d7ab [JAX] Disable qdwh_test in asan/msan/tsan configurations on TPU.
This test is flakily timing out in CI, the sanitizers probably push the test over its time bound.

PiperOrigin-RevId: 642695381
2024-06-12 12:16:50 -07:00
jax authors
95e2c17b61 Update test of QDWH to use stricter tolerances and test more shapes and types.
Get rid of comparison with scipy.linalg.polar, since its outputs are significantly less accurate than QDWH. Since the polar decomposition is unique, comparing to a less accurate implementation does not add value.

PiperOrigin-RevId: 642423757
2024-06-11 16:04:38 -07:00
Ayaka
1a3a15c9e3 Implement LRU cache eviction for persistent compilation cache
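A sketch of how a user might combine this with the existing cache options; `jax_compilation_cache_dir` is an existing option, while the size-cap name below is an assumption, not taken from this commit:

```py
import jax

# Point JAX at a persistent on-disk compilation cache.
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")

# Hypothetical size cap; once exceeded, least-recently-used entries would be evicted.
jax.config.update("jax_compilation_cache_max_size", 256 * 1024 * 1024)
```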
Co-authored-by: Sergei Lebedev <slebedev@google.com>
2024-06-11 21:48:35 +04:00
Eugene Zhulenev
4cca233522 [xla:cpu] Add support for outfeed to thunk runtime
+ enabled infeed/outfeed test in jax

PiperOrigin-RevId: 640026208
2024-06-03 22:43:42 -07:00
Eugene Zhulenev
1aa2b6ee4f [jax:cpu] Add a test config to run JAX with new XLA:CPU runtime
XLA:CPU is migrating from compiling a monolithic LLVM function for the whole HLO module to a thin runtime with separate functions for each kernel (fusion, individual operation, library call, etc.). While the new runtime is not enabled by default, we will use explicit opt-in on tests that are already compatible.

This tag will be removed once XLA:CPU switches to the new runtime by default.

PiperOrigin-RevId: 640022517
2024-06-03 22:28:18 -07:00
Dan Foreman-Mackey
690fa1d90c Remove failing ffi test
The FFI headers aren't properly exposed during a bazel build, so these
tests are failing. I'll re-enable them when I get a chance to get that
working properly.
2024-05-31 15:36:33 -04:00
jax authors
d9f07d0350 Merge pull request #21531 from dfm:move-ffi-submodule
PiperOrigin-RevId: 639077602
2024-05-31 10:31:32 -07:00
Dan Foreman-Mackey
1e206880d3 Move jax.ffi submodule to jax.extend.ffi 2024-05-31 12:34:59 -04:00
Peter Hawkins
01b4cb6de0 Bump memory limits for array_test and layout_test on TPU CI.
These use more than our CI's default memory limit (12GB) when run under tsan.

PiperOrigin-RevId: 638618718
2024-05-30 05:33:14 -07:00
Peter Hawkins
1e01fa7b0f Bump up memory requirements for pmap_test on TPU.
This test recently started exceeding the default memory requirement when run under tsan. I'm not entirely sure why, but perhaps some change pushed it just over our CI's 12GB default limit.

PiperOrigin-RevId: 636910434
2024-05-24 07:25:30 -07:00
Dan Foreman-Mackey
88790711e8 Package XLA FFI headers with jaxlib wheel
The new "typed" API that XLA provides for foreign function calls is
header-only and packaging it as part of jaxlib could simplify the open
source workflow for building custom calls.

It's not completely obvious that we need to include this, because jaxlib
isn't strictly required as a _build_ dependency for FFI calls, although
it typically will be required as a _run time_ dependency. Also, it
probably wouldn't be too painful for external projects to use the
headers directly from the openxla/xla repo.

All that being said, I wanted to figure out how to do this, and it has
been requested a few times.
2024-05-22 12:28:38 -04:00
Peter Hawkins
483f924ea1 Bump shard count for experimental_rnn_test, which is timing out in CI when built under ASAN.
PiperOrigin-RevId: 635850400
2024-05-21 10:25:24 -07:00