343 Commits

Author SHA1 Message Date
Bill Varcho
f22bafac31 [SDY] remove TODO for enabling Layouts for Shardy post cl/697715276.
PiperOrigin-RevId: 700053383
2024-11-25 11:45:00 -08:00
Bill Varcho
bb1024f3fd [SDY] enable cpu_shardy for JAX shard_alike test.
PiperOrigin-RevId: 700029576
2024-11-25 10:33:17 -08:00
Bill Varcho
0ed6eaeb4a [SDY] fix JAX layouts tests for Shardy.
PiperOrigin-RevId: 697715276
2024-11-18 12:14:32 -08:00
Dan Foreman-Mackey
ccb331707e Add a GPU implementation of lax.linalg.eig.
This feature has been in the queue for a long time (see https://github.com/jax-ml/jax/issues/1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (https://github.com/jax-ml/jax/issues/24255; this should be investigated separately).
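
For context, here is a minimal, hedged sketch of that `pure_callback` workaround (the helper name and dtypes are illustrative assumptions, not code from the linked issue): call NumPy's CPU eigendecomposition from inside jitted code, keeping the callback body free of JAX calls.

```py
# Hedged sketch of the pure_callback workaround mentioned above; eig_host and
# the complex64 dtypes are illustrative assumptions.
import jax
import jax.numpy as jnp
import numpy as np

def eig_host(a):
  w_shape = jax.ShapeDtypeStruct(a.shape[:-1], jnp.complex64)
  v_shape = jax.ShapeDtypeStruct(a.shape, jnp.complex64)
  def _np_eig(a):  # runs on host; NumPy only, no JAX calls in the body
    w, v = np.linalg.eig(a)
    return w.astype(np.complex64), v.astype(np.complex64)
  return jax.pure_callback(_np_eig, (w_shape, v_shape), a)

w, v = jax.jit(eig_host)(jnp.eye(3, dtype=jnp.float32))
```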

This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.)

We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_use_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable.
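
A hedged usage sketch follows; the config name `jax_gpu_use_magma` and the `JAX_GPU_MAGMA_PATH` variable come from this change, while the library path and matrix size shown are assumptions.

```py
# Sketch only: assumes an NVIDIA GPU, a MAGMA build installed at the path
# below, and a JAX version that includes this change.
import os
os.environ["JAX_GPU_MAGMA_PATH"] = "/opt/magma/lib/libmagma.so"  # assumed path

import jax
import jax.numpy as jnp

jax.config.update("jax_gpu_use_magma", "on")  # opt in to the MAGMA backend

a = jnp.ones((4096, 4096), dtype=jnp.float32)  # large enough for MAGMA to pay off
w, vl, vr = jax.lax.linalg.eig(a)  # eigenvalues, left/right eigenvectors
```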

PiperOrigin-RevId: 697631402
2024-11-18 08:11:57 -08:00
jax authors
a1eb5ceade Merge pull request #23374 from jaro-sevcik:mock-topology-config
PiperOrigin-RevId: 696540499
2024-11-14 08:55:04 -08:00
jax authors
12c8c68c4a Merge pull request #24069 from sergachev:cudnn_fusion_test_a100
PiperOrigin-RevId: 696200281
2024-11-13 11:06:08 -08:00
Jaroslav Sevcik
eedd01118b Add an option to specify a mock GPU topology.
2024-11-12 08:36:27 -08:00
Peter Hawkins
7491fdd94c Disable for_loop_test on TPU v5p.
This test is failing in CI.

PiperOrigin-RevId: 695278007
2024-11-11 04:09:44 -08:00
Peter Hawkins
7285f10e84 Disable lax_test on ARM in Google's internal CI.
There are numerical errors from the complex plane function tests.

PiperOrigin-RevId: 694579368
2024-11-08 11:33:19 -08:00
Bill Varcho
afd8239ea4 [SDY] add JAX lowering to Shardy ShardingGroupOp for shard_alike.
PiperOrigin-RevId: 694567084
2024-11-08 11:02:50 -08:00
Peter Hawkins
3b2e4a1600 Remove sharding from custom_root_test.
This test only takes around 30s on most hardware platforms, so it does not need 10 shards.

PiperOrigin-RevId: 694243316
2024-11-07 14:12:21 -08:00
Peter Hawkins
ea1e879577 Include mpmath as a bazel dependency of lax_test.
This test has additional test cases that require mpmath.

PiperOrigin-RevId: 693464078
2024-11-05 13:43:06 -08:00
Ilia Sergachev
e083c08001 Re-enable cudnn_fusion_test on A100.
Check that the required cuDNN version is available.
2024-11-01 15:48:07 +00:00
Bart Chrzaszcz
44158ab0e4 #sdy add shardy CPU config for all JAX tests, disabling any known failing test cases.
The only test cases breaking on CPU are related to:
- pure callbacks
- export
- shard alike

Note that `layout_test` is broken on TPU; a comment was left saying to enable it later.

Also fixed the `shard_map_test` test that was broken when running Shardy on one TPU, and `aot_test`, which was breaking due to calling a different C++ StableHLO compilation function.

PiperOrigin-RevId: 691496997
2024-10-30 11:40:20 -07:00
Jake VanderPlas
e61a20b45a Remove deprecated jax.experimental.export module.
These tools are now available at jax.export.
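
A minimal sketch of the replacement location, assuming the standard `jax.export` API (not part of this commit's diff):

```py
# Sketch: export, serialize, and round-trip a jitted function via jax.export,
# the new home of the tools formerly in jax.experimental.export.
import jax
import jax.numpy as jnp
from jax import export

def f(x):
  return jnp.sin(x)

exported = export.export(jax.jit(f))(jnp.arange(4.0))
blob = exported.serialize()              # portable serialized artifact
rehydrated = export.deserialize(blob)
print(rehydrated.call(jnp.arange(4.0)))
```
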
2024-10-30 05:27:29 -07:00
Yash Katariya
e35e7f8e20 Allow sparsecore compute with T(8) layout via the layout API and compute_on API. To annotate compute on sparsecore, use @compute_on('tpu_sparsecore').
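
A hedged sketch of the annotation, assuming `compute_on` is imported from `jax.experimental.compute_on` and the program runs on a TPU with a SparseCore; the toy workload is made up.

```py
# Illustration only; 'tpu_sparsecore' comes from the commit message, the rest
# is an assumed toy workload.
import jax
import jax.numpy as jnp
from jax.experimental.compute_on import compute_on

@compute_on('tpu_sparsecore')
@jax.jit
def gather_rows(table, idx):
  return table[idx]  # run this gather on the SparseCore

@jax.jit
def step(table, idx):
  return gather_rows(table, idx).sum()

print(step(jnp.arange(64.0).reshape(8, 8), jnp.array([0, 3, 5])))
```
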
PiperOrigin-RevId: 691225280
2024-10-29 17:58:53 -07:00
Hyeontaek Lim
77797f434d [JAX] Add the function API of jax.experimental.colocated_python
This change adds an experimental API `jax.experimental.colocated_python`. The
ultimate goal of this API is to provide a runtime-agnostic way to wrap Python
code that runs close to (or on) accelerator hosts. Multi-controller JAX can
trivially achieve this colocated Python code execution today, while
single-controller JAX needs its own solution for distributed Python code
execution, which fragments user code across the two runtime architectures.
`colocated_python` is an attempt to define a single device model and portable
API that lets the user write the code once and run it on both runtime
architectures.

This change includes an implementation of the function API portion of
`jax.experimental.colocated_python`. A (stateful) object API will be added
separately. Also there will be a separate change that expresses serialized
functions as an IFRT `CustomCallProgram`.

It is currently in an early development stage. Please proceed with caution
when using the API.
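
A minimal sketch of the function API, with the caveat that the API surface is early-stage and may differ; the decorator name, the CPU-device helper, and the runtime support are assumptions based on the module described above.

```py
# Sketch only; assumes jax.experimental.colocated_python exposes a
# colocated_python decorator and a colocated_cpu_devices helper, and that the
# runtime supports colocated Python execution.
import jax
import jax.numpy as jnp
from jax.experimental import colocated_python

@colocated_python.colocated_python
def add_one(x):
  # Runs as Python on the hosts colocated with x's devices.
  return jax.tree.map(lambda v: v + 1, x)

cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices())
x = jax.device_put(jnp.arange(4), cpu_devices[0])
print(add_one(x))
```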

PiperOrigin-RevId: 690705899
2024-10-28 12:18:48 -07:00
jax authors
1336c2d5c4 Fix breaking PGLE test-cases
PiperOrigin-RevId: 690608075
2024-10-28 07:50:31 -07:00
Ionel Gog
ec279f9c54 Add config option to log or fatal when jax.Arrays are GCed.
Introduces `jax.config.array_garbage_collection_guard`, a tri-state config for setting up a `jax.Array` garbage collection guard (see the sketch below). The possible values are:
* allow: `jax.Array`s are allowed to be garbage collected. This is the default value.
* log: whenever a `jax.Array` is GCed, a log entry is generated with the array's traceback.
* fatal: fatal crash when a `jax.Array` is GCed. This is meant for mature code bases that do tight memory management and are reference-cycle free.
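
A minimal sketch of the `log` mode; the flag name passed to `jax.config.update` is assumed to be the `jax_`-prefixed form of the option above.

```py
# Illustration: a reference cycle forces the arrays to be reclaimed by the
# cycle collector instead of plain reference counting, which the guard reports.
import gc
import jax
import jax.numpy as jnp

jax.config.update("jax_array_garbage_collection_guard", "log")

class Node:
  def __init__(self, value):
    self.value = value
    self.other = None

a, b = Node(jnp.ones(4)), Node(jnp.zeros(4))
a.other, b.other = b, a   # cycle keeps both nodes (and their arrays) alive
del a, b
gc.collect()              # arrays are GCed here; the guard logs their tracebacks
```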

PiperOrigin-RevId: 687003464
2024-10-17 12:23:16 -07:00
Bart Chrzaszcz
fb32841b1b #sdy add JAX Shardy support for memories.
PiperOrigin-RevId: 684867097
2024-10-11 09:44:24 -07:00
Peter Hawkins
66f526894f Reenable some test cases that were disabled due to bugs that now seem fixed.
PiperOrigin-RevId: 684464642
2024-10-10 09:06:06 -07:00
Peter Hawkins
19dbff5326 Move additional CI enabled/disabled configurations into jax BUILD files.
PiperOrigin-RevId: 684457403
2024-10-10 08:41:45 -07:00
George Necula
023f2a78be Remove remaining implementations of jax.experimental.host_callback.call.
The host_callback module has been deprecated since March 2024, and we are now removing the implementation. We keep the functions so that we can give a nicer error message than AttributeError, and because removing them now breaks internal pytype checking. We will remove them in the near future.

See https://github.com/google/jax/issues/20385.

PiperOrigin-RevId: 683564340
2024-10-08 04:22:20 -07:00
Adam Paszke
7102c7adbf Bump the shard_count of FFT tests to avoid timeouts
PiperOrigin-RevId: 683537643
2024-10-08 02:44:41 -07:00
George Necula
db89c245ac [host_callback] Remove most of the jax.experimental.host_callback module
These APIs have been deprecated since March 2024 and they are subsumed by the new JAX external callbacks.
See https://github.com/google/jax/issues/20385 for a discussion.

PiperOrigin-RevId: 682830525
2024-10-06 01:10:34 -07:00
George Necula
b8a066a907 [host_callback] Remove obsolete tests.
Removing tests that only work in legacy mode and with outfeed.

PiperOrigin-RevId: 681435113
2024-10-02 06:51:02 -07:00
Peter Hawkins
1260ebbe05 Disable cudnn_fusion_test on A100.
This test only seems to pass on H100 at the moment.

PiperOrigin-RevId: 681070398
2024-10-01 10:18:41 -07:00
Ilia Sergachev
b320dc2e5e Fix and reenable cudnn_fusion_test.
Disable XLA autotuning fallback to cuBLAS so that the tested fusion
always executes through cuDNN.
2024-09-30 14:03:55 +00:00
Peter Hawkins
5969e79908 Fix tests that ask for an accelerator but don't use it.
* Delete custom_object_test, since it is disabled and has been ever since jax.Array was introduced in JAX 0.4.0.
* custom_linear_solve_test was over-sharded, leading to some shards not having any test cases. Even unsharded it completes in under 65s on every platform we have.
* config_test and pallas splash attention mask test only tested helpers and didn't need a TPU.

PiperOrigin-RevId: 679711664
2024-09-27 13:36:23 -07:00
Peter Hawkins
26632fd344 Replace disable_backends with enable_backends on jax_multiplatform_test.
Most users of disable_backends were actually using it to enable only a single backend, so things are simpler if we negate the sense of the option. Change disable_configs to enable_configs, with a default `None` value meaning "everything is enabled".

We change the relationship between enable_backends, disable_configs, and enable_configs to be the following (see the sketch after this list):
* `enable_backends` selects a set of initial test configurations to enable, based off backend only.
* `disable_configs` then prunes that set of test configurations, removing elements from the set.
* `enable_configs` then adds additional configurations to the set.
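
A hypothetical BUILD sketch of that relationship; the target name and the config labels are invented for illustration.

```py
# Hypothetical Bazel BUILD sketch (Starlark); the test name and the config
# labels shown are invented for illustration.
jax_multiplatform_test(
    name = "example_test",
    srcs = ["example_test.py"],
    enable_backends = ["cpu", "gpu"],   # 1) start from all cpu/gpu configs
    disable_configs = ["gpu_a100"],     # 2) prune this configuration
    enable_configs = ["gpu_h100"],      # 3) then add this one back in
)
```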

Fix code in jax/experimental/mosaic/gpu/examples not to depend on a Google-internal GPU support target.

PiperOrigin-RevId: 679563155
2024-09-27 06:15:31 -07:00
Bart Chrzaszcz
a3284bd8a3 #sdy Add CPU targets in JAX.
PiperOrigin-RevId: 679174535
2024-09-26 09:13:34 -07:00
Bart Chrzaszcz
e62a50cd34 #sdy add JAX Shardy support for shard_map.
For example, the following JAX program:
```py
from functools import partial

import numpy as np

import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

devices = np.array(jax.devices()[:8])
mesh = Mesh(devices, axis_names=('x',))
a = jax.device_put(
    jnp.arange(8 * 8).reshape((8, 8)),
    jax.sharding.NamedSharding(mesh, P('x', None)))

@jax.jit
@partial(
    shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None)
)
def fwd(a):
  axis_size = lax.psum(1, 'x')
  perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
  return lax.ppermute(a, 'x', perm=perm)

print(jax.jit(fwd).lower(a).as_text())
```

prints:

```mlir
module @jit_fwd attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <["x"=8]>
  func.func public @main(%arg0: tensor<8x8xi32> {mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<8x8xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = call @fwd(%arg0) : (tensor<8x8xi32>) -> tensor<8x8xi32>
    return %0 : tensor<8x8xi32>
  }
  func.func private @fwd(%arg0: tensor<8x8xi32> {mhlo.layout_mode = "default"}) -> (tensor<8x8xi32> {mhlo.layout_mode = "default"}) {
    %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"x"}, {}]>] out_shardings=[<@mesh, [{"x"}, {}]>] manual_axes={"x"} (%arg1: tensor<1x8xi32>) {
      %1 = "stablehlo.collective_permute"(%arg1) <{channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>, source_target_pairs = dense<[[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 0]]> : tensor<8x2xi64>}> : (tensor<1x8xi32>) -> tensor<1x8xi32>
      sdy.return %1 : tensor<1x8xi32>
    } : (tensor<8x8xi32>) -> tensor<8x8xi32>
    return %0 : tensor<8x8xi32>
  }
}
```

PiperOrigin-RevId: 679165100
2024-09-26 08:45:40 -07:00
Peter Hawkins
1949413739 Increase sharding of checkify_test on TPU to fix CI flakes.
PiperOrigin-RevId: 678720498
2024-09-25 08:54:29 -07:00
Peter Hawkins
a43c7f2ace Enable more H100 tests in CI.
Rename "gpu" config CI tag to "gpu_v100".

PiperOrigin-RevId: 678695003
2024-09-25 07:37:48 -07:00
Peter Hawkins
85a466d730 Lower the shard count for sparse_bcoo_bcsr_test on TPU as well.
There are flaky timeouts in CI, and we've already lowered the shard count on multiple other platforms.

PiperOrigin-RevId: 678367575
2024-09-24 13:10:32 -07:00
Peter Hawkins
70f91db853 Set PYTHONWARNINGS=error in bazel tests.
The goal of this change is to catch PRs that introduce new warnings sooner.

To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable.
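
A hedged Starlark sketch of what such a wrapper might look like; the real jax_py_test macro lives in JAX's Bazel rules and may differ in detail.

```py
# Hypothetical Starlark sketch of a py_test wrapper that injects the
# environment variable; details of the real jax_py_test macro may differ.
def jax_py_test(name, env = {}, **kwargs):
    env = dict(env)
    env.setdefault("PYTHONWARNINGS", "error")  # turn new warnings into failures
    native.py_test(name = name, env = env, **kwargs)
```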

Add code to suppress some new warnings uncovered in CI.

PiperOrigin-RevId: 678352286
2024-09-24 12:30:11 -07:00
jax authors
dc1ace5992 Re-enable tsan tests after fix.
PiperOrigin-RevId: 677895934
2024-09-23 12:26:30 -07:00
jax authors
28b5dee032 Disable flaky tsan tests temporarily.
PiperOrigin-RevId: 674338720
2024-09-13 10:03:24 -07:00
Peter Hawkins
95f38d95d7 Update TPU test configuration tags.
PiperOrigin-RevId: 672562923
2024-09-09 09:02:51 -07:00
Peter Hawkins
fe63b991dd Disable cudnn_fusion_test from CI.
This test isn't passing in our internal CI.

PiperOrigin-RevId: 672507574
2024-09-09 05:16:13 -07:00
jax authors
02b7a76768 Add frontend attributes to JAX. This allows JAX users to annotate JAX code with frontend_attributes that can be traced down to the HLO level, to be used for numerical debugging purposes.
PiperOrigin-RevId: 671930431
2024-09-06 16:44:56 -07:00
Ilia Sergachev
85d792a92d Add a cudnn_fusion decorator that lowers computations to XLA cuDNN fusions.
2024-09-05 01:25:54 +02:00
jax authors
2f3990d13c Remove CPU test variant.
PiperOrigin-RevId: 669359594
2024-08-30 09:58:32 -07:00
Peter Hawkins
6d1f51e63d Clean up BUILD files.
PiperOrigin-RevId: 667604964
2024-08-26 09:11:17 -07:00
Bart Chrzaszcz
71b7e78916 Add jax_test configs for Shardy, enable it for pjit_test.py, and fix failing tests.
Tests fixed include:

- `test_globally_sharded_key_array_8x4_multi_device`
  - Issue was in `replicate_trailing_dims` where an `xc.OpSharding` was always created. Fixed by creating an equivalent SDY sharding.
- `test_aot_out_info`
  - Issue was that there was no mesh, since there weren't any NamedShardings. Fixed by not asserting that a mesh tuple exists in `lower_jaxpr_to_module` when adding the sdy MeshOp (there won't be any propagation).
- `test_concurrent_pjit`
  - In Shardy, if there was a tensor dimension of size 0, we'd emit a verification error if the dimension was sharded on an axis. But if the axis is of size 1, then JAX says this is okay, so Shardy now assumes the same.
- `test_globally_sharded_key_array_result_8x4_single_device`
  - This test adds a WSC when no `mesh_shape_tuple` exists (`"sdy.sharding_constraint"(%8) <{sharding = #sdy.sharding<@mesh, [{?}, {?}, {}]>}>`), so we create a mesh named `mesh` with a single device ID in case it doesn't exist.
- `testLowerCostAnalysis`
  - This calls into `mlir_module_to_xla_computation`, which calls its own MLIR parsing function in `//third_party/tensorflow/compiler/xla/python/mlir.cc`. Needed to register the SDY dialect in it.
- `testShardingConstraintWithArray`
  - This calls `.compiler_ir(dialect="hlo")`, which calls `PyMlirModuleToXlaComputation`, which converts the MLIR to HLO, but the SDY dialect is still inside. Fixed by exporting it before converting to HLO.

PiperOrigin-RevId: 666777167
2024-08-23 06:51:13 -07:00
George Necula
dbd6aeebb7 Disable some asan tests that time out.
PiperOrigin-RevId: 662774152
2024-08-13 22:03:29 -07:00
Sergei Lebedev
d8eafc8ee3 Disabled nn_test under asan on TPU as well, since it also times out
PiperOrigin-RevId: 660950262
2024-08-08 13:02:31 -07:00
Sergei Lebedev
3a1567f57a Do not run nn_test under asan -- it times out
PiperOrigin-RevId: 660377176
2024-08-07 07:14:27 -07:00
Vladimir Belitskiy
7f96b263d4 Un-skip //third_party/py/jax/tests:pytorch_interoperability_test_cpu on ASAN.
It should have been fixed via
https://github.com/pytorch/pytorch/issues/117058#issuecomment-1973020150

PiperOrigin-RevId: 656464550
2024-07-26 11:10:41 -07:00
Vladimir Belitskiy
282ebf4882 Skip //third_party/py/jax/tests:pytorch_interoperability_test_cpu on MSAN.
MSAN has issues with using `-c opt` in some cases, which prevents this
test from running properly.

PiperOrigin-RevId: 656454585
2024-07-26 10:44:19 -07:00