377 Commits

Author SHA1 Message Date
jax authors
8af6f70fe0 [JAX] Disable msan and asan for the profiler test running on nvidia gpu
PiperOrigin-RevId: 733380848
2025-03-04 10:34:11 -08:00
Bart Chrzaszcz
ac493655bf #sdy support JAX export tests when Shardy is enabled.
This CL only supports lowering a module in which all shardings use the exact same mesh, and loading it with either that same mesh or a different one.

Note that we will be introducing some restrictions under Shardy for JAX export:

- You can only lower/save the module with meshes that all have the same shape but possibly different axis names (right now this PR only allows identical axis names, but this will be relaxed in a follow-up).
- When loading the module, just like with GSPMD, you can use a mesh with a different shape and different axis names. However, as in the previous point, all shardings must use the same axis sizes, though they can use different axis names (again, this will be relaxed in a follow-up).

In the future we may remove the restriction that the mesh shape used when saving the export exactly match the mesh shape used when loading it, but for now we will keep this restriction while no one is using Shardy with JAX export.

PiperOrigin-RevId: 732878916
2025-03-03 04:57:06 -08:00
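A minimal sketch of the save/load round trip these restrictions apply to, using the standard `jax.export` API; the mesh and sharding below are illustrative assumptions, not from the commit:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=("x",))
sharding = NamedSharding(mesh, P("x"))

@jax.jit
def f(a):
    return a * 2

# Lower/save the module under one mesh...
exported = jax.export.export(f)(
    jax.ShapeDtypeStruct((8,), jnp.float32, sharding=sharding))
blob = exported.serialize()

# ...and load it later, possibly under a different mesh with the same axis sizes.
restored = jax.export.deserialize(blob)
```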
jax authors
48a55a6d71 Add a profiler test for gpu run
PiperOrigin-RevId: 732247572
2025-02-28 13:45:46 -08:00
jax authors
c7ca35fe32 Merge pull request #26345 from wenscarl:scaled_matmul
PiperOrigin-RevId: 731865430
2025-02-27 14:24:48 -08:00
Dan Foreman-Mackey
f93c2a1aa5 Add and test support for partitioning of batch dimensions in lax.linalg.
On CPU and GPU, almost all of the primitives in lax.linalg are backed by custom calls that support simple semantics when batch dimensions are sharded. Before this change, every linalg operation on CPU and GPU would insert an `all-gather` before executing when called on sharded inputs, even when that shouldn't be necessary. This change adds support for this type of partitioning, covering a wide range of use cases.

There are a few remaining GPU ops that don't support partitioning, either because they are backed by HLO ops that don't partition properly (Cholesky factorization and triangular solves), or because they still use descriptors with the problem dimensions baked into the kernel. I'm going to fix these in follow-up changes.

PiperOrigin-RevId: 731732301
2025-02-27 08:16:16 -08:00
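As a rough illustration of the behavior described above (not from the commit; the choice of `qr` is an assumption), batch-sharded inputs can now be factorized shard by shard:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Shard the leading (batch) dimension of a stack of matrices across devices.
mesh = Mesh(jax.devices(), axis_names=("batch",))
x = jnp.zeros((8, 16, 16))
x = jax.device_put(x, NamedSharding(mesh, P("batch", None, None)))

# Each device factorizes its own shard of matrices; previously this would
# have inserted an all-gather of `x` first.
q, r = jax.jit(jax.lax.linalg.qr)(x)
```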
jax authors
615219b1f6 Remove tensorstore dependency from //jax/experimental/array_serialization:serialization in OSS (see https://github.com/google/tensorstore/issues/218)
Disable serialization_test in OSS.

PiperOrigin-RevId: 731463136
2025-02-26 14:47:16 -08:00
jax authors
eb912ad0d9 Create jax wheel build target.
This change introduces a uniform way of building the artifacts and controlling the filename version suffixes (see the changes for `jaxlib`, `jax-cuda-plugin` and `jax-cuda-pjrt` in https://github.com/jax-ml/jax/pull/25126)

Previously, the `jax` wheel was built via the `python3 -m build` command. The resulting wheel contained the Python package files in the `jax` folder (i.e. the files in the subdirectories that have an `__init__.py` file).

You can still build the `jax` wheel with `python3 -m build` command.

Bazel `jax` wheel target: `//:jax_wheel`

Environment variable combinations for creating wheels with different versions:
  * self-built wheel (default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot`
  * release: `--repo_env=ML_WHEEL_TYPE=release`
  * release candidate: `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=-rc1`
  * nightly build: `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=<YYYYmmdd> --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)`

PiperOrigin-RevId: 730916743
2025-02-25 09:30:08 -08:00
Peter Hawkins
673a02d614 Don't set PYTHONWARNINGS=error for tests that use TensorFlow.
Protobuf, which is used by TF, sometimes emits a warning under Python 3.12.

PiperOrigin-RevId: 729554740
2025-02-21 09:22:59 -08:00
Zac Cranko
5db78e7ae0 add distributed.is_initialized 2025-02-18 16:47:19 -08:00
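A short usage sketch for the new predicate (the guard pattern here is an assumption, not from the commit):

```python
import jax

# Initialize the distributed runtime only once per process, e.g. in library
# code that may run before or after the application has called initialize().
if not jax.distributed.is_initialized():
    jax.distributed.initialize()
```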
Ayaka
6addf02add Add JAX error checking support
In this PR, only `jit` and control flow are supported. Support for `vmap` and multi-device environments will be added in subsequent PRs.

PiperOrigin-RevId: 726920440
2025-02-14 07:28:21 -08:00
Adam Paszke
8eea88626f Skip CPU ASAN in scipy special function tests
They just take too long under ASAN.

PiperOrigin-RevId: 725965329
2025-02-12 02:13:25 -08:00
Dan Foreman-Mackey
c521bc6205 [xla:python] Add a mechanism for "batch partitioning" of FFI calls.
This is the first in a series of changes to add a simple API for supporting a set of common sharding and partitioning patterns for FFI calls. The high-level motivation is that custom calls (including FFI calls) are opaque to the SPMD partitioner, and the only ways to customize the partitioning behavior are to (a) explicitly register an `xla::CustomCallPartitioner` with XLA, or (b) use the `jax.experimental.custom_partitioning` APIs. Option (a) isn't generally practical for most use cases where the FFI handler lives in an external binary. Option (b) is flexible and supports all common use cases, but it requires embedding Python callbacks into the HLO, which can lead to issues including cache misses. Furthermore, `custom_partitioning` is overpowered for many use cases, where only (what I will call) "batch partitioning" is needed.

In this case, "batch partitioning" refers to the behavior of many FFI calls where they can be trivially partitioned on some number of (leading) dimensions, with the same call being executed independently on each shard of data. If the data are sharded on non-batch dimensions, partitioning will still re-shard the data to be replicated on the non-batch dimensions. This kind of partitioning logic applies to all the LAPACK/cuSOLVER/etc.-backed linear algebra functions in jaxlib, as well as some external users of `custom_partitioning`.

The approach I'm taking here is to add a new registration function to the XLA client, which lets a user label their FFI call as batch partitionable. Then, when lowering the custom call, the user passes the number of batch dimensions as a frontend attribute, which is then interpreted by the SPMD partitioner.

In parallel with this change, Shardy has added support for sharding propagation across custom calls using a string representation that is similar in spirit to this approach, but somewhat more general. However, the Shardy implementation still requires a Python callback for the partitioning step, so it doesn't (yet!) solve all of the relevant problems with the `custom_partitioning` approach. Ultimately, it should be possible to have the partitioner parse the Shardy sharding rule representation, but I wanted to start with the minimal implementation.

PiperOrigin-RevId: 724367877
2025-02-07 09:14:06 -08:00
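A hedged sketch of the batch-partitionable FFI pattern described above; the registration helper named in the comments is an assumption based on the commit text, not confirmed public API:

```python
import jax
import jax.numpy as jnp

# A hypothetical FFI target "my_solve" that operates independently on each
# leading batch element would first be registered as usual, then labeled
# batch partitionable so the SPMD partitioner may shard the batch dimensions:
#
#   jax.ffi.register_ffi_target("my_solve", my_solve_capsule, platform="CUDA")
#   register_custom_call_as_batch_partitionable("my_solve")  # assumed helper

def solve(x):  # x: (batch, n, n)
    # At lowering time the number of batch dimensions would be attached as a
    # frontend attribute, letting each shard run the same call independently.
    return jax.ffi.ffi_call(
        "my_solve", jax.ShapeDtypeStruct(x.shape, x.dtype))(x)
```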
George Necula
a678396f44 Increase test shard_count for shape_poly_test on GPU
PiperOrigin-RevId: 723940915
2025-02-06 08:13:55 -08:00
Michael Hudgins
2e808f2836 Merge pull request #26279 from MichaelHudgins:tsan-resultstore
PiperOrigin-RevId: 723918760
2025-02-06 14:55:57 +00:00
Gunhyun Park
a8df383ccf Fix lax.ragged_all_to_all degenerate case
In a singleton-group case, unlike regular all_to_all, the ragged op becomes a generic equivalent of DynamicUpdateSlice, except that the update size is not statically known. This operation can't be expressed with standard HLO instructions -- the backend will handle this case separately.

Added small improvement to error messages.

PiperOrigin-RevId: 721473063
2025-01-30 12:05:02 -08:00
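A plain-NumPy reference of the degenerate semantics described above (an illustrative assumption, not the actual backend implementation):

```python
import numpy as np

def singleton_ragged_all_to_all_reference(operand, output, input_offsets,
                                          send_sizes, output_offsets,
                                          recv_sizes):
    # With a single participant, "all-to-all" reduces to copying each ragged
    # chunk into `output` at a runtime offset: effectively a DynamicUpdateSlice
    # whose update size is only known at run time.
    out = np.array(output)
    for i in range(len(input_offsets)):
        chunk = operand[input_offsets[i]:input_offsets[i] + send_sizes[i]]
        out[output_offsets[i]:output_offsets[i] + recv_sizes[i]] = chunk
    return out
```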
Justin Fu
b01111d96c Add skeleton for a multi-pass source mapper for Jaxprs/HLO to jax.experimental.
PiperOrigin-RevId: 721119935
2025-01-29 15:01:43 -08:00
Gunhyun Park
809e1133c8 Add support for axis_name and axis_index_groups to lax.ragged_all_to_all
PiperOrigin-RevId: 720738861
2025-01-28 16:02:03 -08:00
Peter Hawkins
faaaf82974 Disable pytorch_interoperability_test under asan on all backends.
It wasn't sufficient to disable this only on GPU.

PiperOrigin-RevId: 720344366
2025-01-27 16:18:28 -08:00
Peter Hawkins
42fd586e79 Disable pytorch_interoperability_test under asan.
PiperOrigin-RevId: 720189636
2025-01-27 09:00:40 -08:00
George Necula
e4d5427d13 [better_errors] Add more debug info test coverage
Try to cover the tracing of almost all JAX higher-order
primitives. Some of the tests added show missing debug info;
these are marked with TODO. Fixes will come separately.

I had to expand the helper function `_check_tracers_and_jaxprs` to
use regular expressions for matching because some debug info
still contains non-deterministic elements.
2025-01-26 08:12:29 +02:00
Dan Foreman-Mackey
e3b3b913f7 Add an experimental interface for customizing DCE behavior.
We use dead code elimination (DCE) throughout JAX core to remove unused computations from Jaxprs. This typically works transparently when we're just using `lax` primitives, but opaque calls to `pallas_call` or `ffi_call` can't be cleaned up this way. For many kernels, however, the author will know how to generate a more efficient call for specific patterns of used outputs, so it is useful to provide a mechanism for customizing this behavior.

In https://github.com/jax-ml/jax/pull/22735, I attempted to automatically tackle one specific example of this that comes up frequently, but there have been feature requests for a more general API. This version is bare bones and probably rough around the edges, but it could be a useful starting point for iteration.

PiperOrigin-RevId: 718950828
2025-01-23 11:38:47 -08:00
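A sketch of how the interface can be used, based on the commit description (the exact rule signature is an assumption):

```python
import jax.numpy as jnp
from jax.experimental import custom_dce

@custom_dce.custom_dce
def f(x):
    # Imagine each output is produced by an expensive opaque kernel
    # (e.g. a pallas_call or ffi_call) instead of these stand-ins.
    return jnp.sin(x), jnp.cos(x)

def f_dce_rule(used_outs, x):
    # used_outs has one boolean per output; return None for dead outputs so
    # only the live computation is emitted.
    return (jnp.sin(x) if used_outs[0] else None,
            jnp.cos(x) if used_outs[1] else None)

f.def_dce(f_dce_rule)
```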
George Necula
849ccc978b [better_errors] Expand the tests for debug_info
Debugging info is needed for error messages and for
lowering. For the former, we need debug info inside
tracers; for the latter, inside Jaxprs. We add a
new set of tests that intentionally leak tracers while
tracing and then check that the tracers have the
expected debug info. We also form Jaxprs and
check that they have the expected debug info.
We uncovered a few missing debug infos; those are
marked with TODO.
2025-01-22 16:49:16 +01:00
George Necula
e5d89e738a [better_errors] Refactor debug info tests
Created debug_info_test.py and moved some of the tests
involving debug_info there. In the future we will put more
tests for debugging info, and their helper functions, here.
2025-01-20 20:21:01 +01:00
Peter Hawkins
f122f17b27 Rename test configs to include GPU variants more consistently.
* Include "p100" or "v100" in the default "gpu" config names, matching their current CI configuration.
* Rename "_2gpu" test variants to "x2" variants, since this is more succinct.

This change is intended to be a pure renaming, and it is not intended to alter the set of tests that run.

PiperOrigin-RevId: 715468944
2025-01-14 11:55:45 -08:00
Bart Chrzaszcz
dc53c563bb #sdy enable pure callbacks and debug prints in JAX.
Everything passes other than an io callback test due to the lowered `sdy.manual_computation` returning a token. Will be fixed in a follow-up.

PiperOrigin-RevId: 713780181
2025-01-09 13:37:51 -08:00
Peter Hawkins
b06779b177 Switch to a new thread-safe utility for catching warnings.
The Python `warnings.catch_warnings()` functionality is not thread-safe (https://py-free-threading.github.io/porting/#the-warnings-module-is-not-thread-safe), so we cannot use it during tests that use free-threading. This change introduces a private warnings test helper (test_warning_util.py), which hooks the CPython warning infrastructure and uses it to implement thread-safe warning handling.

This requires a handful of small modifications to tests to remove direct uses of the warnings module. We also sadly have to delete one TPU test that checks for a warning raised on another thread; there's no easy way for us to catch that in a thread-safe way, but that test seems like overkill anyway.
2025-01-09 11:58:34 -05:00
jax authors
56f0f9534d Merge pull request #25633 from dfm:move-ffi
PiperOrigin-RevId: 712863350
2025-01-07 04:40:21 -08:00
Jake VanderPlas
c7b0d681bd Remove deprecated jax.experimental.array_api 2025-01-06 15:19:02 -08:00
Dan Foreman-Mackey
cb4d97aa1f Move jex.ffi to jax.ffi. 2024-12-29 13:06:19 +00:00
Berkin Ilbeyi
f17b2bc2d3 Reenable for_loop_test on TPU v5p.
PiperOrigin-RevId: 704298792
2024-12-09 08:38:41 -08:00
Bixia Zheng
2a4a0e8d6f [jax:custom_partitioning] Implement SdyShardingRule to support Shardy custom_partitioning.

The parsing of the sharding rule string very closely follows how einops parses
its rules in einops/parsing.py.

When an SdyShardingRule object is constructed, we check the syntax of the
einsum-like notation string and its consistency with the user-provided
factor_sizes, and report errors accordingly. This is done during
f.def_partition.

When SdyShardingRule.build is called, during JAX-to-MLIR lowering, we check
the consistency between the einsum-like notation string, the factor_sizes,
and the MLIR operation, and report errors accordingly.

PiperOrigin-RevId: 703187962
2024-12-05 11:33:23 -08:00
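A hedged sketch of the einsum-like notation (the argument forms below are assumptions based on the commit text, not confirmed API):

```python
from jax.experimental.custom_partitioning import custom_partitioning

@custom_partitioning
def f(x):
    ...  # opaque per-shard computation

# The rule string mirrors einops-style notation: factors shared between the
# operand side and the result side shard together, and factor_sizes supply
# any sizes that cannot be inferred. Syntax and factor_sizes consistency are
# checked during f.def_partition, and checked again against the MLIR
# operation when SdyShardingRule.build runs during lowering, e.g.:
#
#   f.def_partition(..., sharding_rule="... (i j) -> ... i", i=4)  # assumed form
```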
Hyeontaek Lim
e20a483bef [JAX] Add end-to-end execution support in colocated Python API
This change adds a capability to run colocated Python function calls through
`PyLoadedExecutable`. This capability is not yet used for McJAX, but is tested
with a prototype of a colocated Python backend. The overall behavior remains
the same for McJAX (running the user code inline when colocated Python is
called); the new logic will be used once we introduce a colocated Python
backend for McJAX.

Key highlights:

* Colocated Python is compiled into `PyLoadedExecutable` and uses the JAX C++
dispatch path.

* `CustomCallProgram` for a colocated Python compilation now includes
specialization (input/output specs, devices). This information allows a
colocated Python backend to transform input/outputs and validate
PyTree/dtype/shape/sharding.

* `out_specs_fn` now receives `jax.ShapeDTypeStruct`s instead of concrete values.

* Deserialization of devices now prefers the default backend. This improves
compatibility with an environment that uses both a multi-platform backend and
the standard "cpu" backend at the same time.

* Several bugs have been fixed (e.g., correctly using `{}` for kwargs).

PiperOrigin-RevId: 703172997
2024-12-05 10:52:40 -08:00
Enrique Piqueras
8c521547b7 Add experimental JAX roofline API. 2024-11-27 14:38:57 -08:00
Hyeontaek Lim
bbaec6ea59 [JAX] Add Python binding for building a colocated Python program
This change adds a Python binding that makes `ifrt::CustomCallProgram` for a
colocated Python program. This Python binding will be used internally in the
colocated Python API implementation. The API does not yet compile the program
into an executable, which will be added separately.

PiperOrigin-RevId: 700443656
2024-11-26 13:31:15 -08:00
Bill Varcho
f22bafac31 [SDY] remove TODO for enabling Layouts for Shardy post cl/697715276.
PiperOrigin-RevId: 700053383
2024-11-25 11:45:00 -08:00
Bill Varcho
bb1024f3fd [SDY] enable cpu_shardy for JAX shard_alike test.
PiperOrigin-RevId: 700029576
2024-11-25 10:33:17 -08:00
Bill Varcho
0ed6eaeb4a [SDY] fix JAX layouts tests for Shardy.
PiperOrigin-RevId: 697715276
2024-11-18 12:14:32 -08:00
Dan Foreman-Mackey
ccb331707e Add a GPU implementation of lax.linalg.eig.
This feature has been in the queue for a long time (see https://github.com/jax-ml/jax/issues/1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (https://github.com/jax-ml/jax/issues/24255; this should be investigated separately).

This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.)

We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_use_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable.

PiperOrigin-RevId: 697631402
2024-11-18 08:11:57 -08:00
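A usage sketch assembled from the knobs named in the commit message (the install path is a placeholder):

```python
import os

# Point JAX at a non-standard MAGMA install; by default JAX tries
# dlopen("libmagma.so"). Placeholder path:
os.environ["JAX_GPU_MAGMA_PATH"] = "/opt/magma/lib/libmagma.so"

import jax
import jax.numpy as jnp

# Opt in to the MAGMA-backed implementation (default is LAPACK on host).
jax.config.update("jax_gpu_use_magma", "on")

x = jnp.array([[0.0, 1.0],
               [-2.0, -3.0]])
w, vl, vr = jax.lax.linalg.eig(x)  # eigenvalues plus left/right eigenvectors
```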
jax authors
a1eb5ceade Merge pull request #23374 from jaro-sevcik:mock-topology-config
PiperOrigin-RevId: 696540499
2024-11-14 08:55:04 -08:00
jax authors
12c8c68c4a Merge pull request #24069 from sergachev:cudnn_fusion_test_a100
PiperOrigin-RevId: 696200281
2024-11-13 11:06:08 -08:00
Jaroslav Sevcik
eedd01118b Add an option to specify mock GPU topology 2024-11-12 08:36:27 -08:00
Peter Hawkins
7491fdd94c Disable for_loop_test on TPU v5p.
This test is failing in CI.

PiperOrigin-RevId: 695278007
2024-11-11 04:09:44 -08:00
Peter Hawkins
7285f10e84 Disable lax_test on ARM in Google's internal CI.
There are numerical errors from the complex plane function tests.

PiperOrigin-RevId: 694579368
2024-11-08 11:33:19 -08:00
Bill Varcho
afd8239ea4 [SDY] add JAX lowering to Shardy ShardingGroupOp for shard_alike.
PiperOrigin-RevId: 694567084
2024-11-08 11:02:50 -08:00
Peter Hawkins
3b2e4a1600 Remove sharding from custom_root_test.
This test only takes around 30s on most hardware platforms; it does not need 10 shards.

PiperOrigin-RevId: 694243316
2024-11-07 14:12:21 -08:00
Peter Hawkins
ea1e879577 Include mpmath as a bazel dependency of lax_test.
This test has additional test cases that require mpmath.

PiperOrigin-RevId: 693464078
2024-11-05 13:43:06 -08:00
Ilia Sergachev
e083c08001 Re-enable cudnn_fusion_test on A100.
Check that the required cuDNN version is available.
2024-11-01 15:48:07 +00:00
Bart Chrzaszcz
44158ab0e4 #sdy add shardy CPU config for all JAX tests, disabling any known failing test cases.
The only test cases breaking on CPU are related to:
- pure callbacks
- export
- shard alike

Note that `layout_test` is broken on TPU; I left a comment saying to enable it later.

Also fixed a `shard_map_test` case that was broken when running Shardy on one TPU, and `aot_test`, which was breaking due to calling a different C++ StableHLO compilation function.

PiperOrigin-RevId: 691496997
2024-10-30 11:40:20 -07:00
Jake VanderPlas
e61a20b45a Remove deprecated jax.experimental.export module.
These tools are now available at jax.export.
2024-10-30 05:27:29 -07:00
Yash Katariya
e35e7f8e20 Allow sparsecore compute with T(8) layout via the layout API and the compute_on API. To annotate compute on sparsecore, use `@compute_on('tpu_sparsecore')`.
PiperOrigin-RevId: 691225280
2024-10-29 17:58:53 -07:00
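A short sketch of the annotation named in the commit, assuming the existing `compute_on` decorator; the function body is an illustrative placeholder:

```python
import jax
from jax.experimental.compute_on import compute_on

@compute_on("tpu_sparsecore")  # route this computation to the sparsecore
@jax.jit
def gather_rows(table, ids):
    # Illustrative embedding-style lookup; sparsecore favors gather/scatter.
    return table[ids]
```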