367 Commits

Author SHA1 Message Date
Adam Paszke
8eea88626f Skip CPU ASAN in scipy special function tests
They just take too long under ASAN

PiperOrigin-RevId: 725965329
2025-02-12 02:13:25 -08:00
Dan Foreman-Mackey
c521bc6205 [xla:python] Add a mechanism for "batch partitioning" of FFI calls.
This is the first in a series of changes to add a simple API for supporting a set of common sharding and partitioning patterns for FFI calls. The high level motivation is that custom calls (including FFI calls) are opaque to the SPMD partitioner, and the only ways to customize the partitioning behavior is to (a) explicitly register an `xla::CustomCallPartitoner` with XLA, or (b) use the `jax.experimental.custom_partitioning` APIs. Option (a) isn't generally practical for most use cases where the FFI handler lives in an external binary. Option (b) is flexible, and supports all common use cases, but it requires embedding Python callbacks in to the HLO, which can lead to issues including cache misses. Furthermore, `custom_partitioning` is overpowered for many use cases, where only (what I will call) "batch partitioning" is supported.

In this case, "batch partitioning" refers to the behavior of many FFI calls where they can be trivially partitioned on some number of (leading) dimensions, with the same call being executed independently on each shard of data. If the data are sharded on non-batch dimensions, partitioning will still re-shard the data to be replicated on the non-batch dimensions. This kind of partitioning logic applies to all the LAPACK/cuSOLVER/etc.-backed linear algebra functions in jaxlib, as well as some external users of `custom_partitioning`.

The approach I'm taking here is to add a new registration function to the XLA client, which let's a user label their FFI call as batch partitionable. Then, when lowering the custom call, the user passes the number of batch dimensions as a frontend attribute, which is then interpreted by the SPMD partitioner.

In parallel with this change, shardy has added support for sharding propagation across custom calls using a string representation that is similar in spirit to this approach, but somewhat more general. However, the shardy implementation still requires a Python callback for the partitioning step, so it doesn't (yet!) solve all of the relevant problems with the `custom_partitioning` approach. Ultimately, it should be possible to have the partitioner parse the shardy sharding rule representation, but I wanted to start with the minimal implementation.

PiperOrigin-RevId: 724367877
2025-02-07 09:14:06 -08:00
George Necula
a678396f44 Increase test shard_count for shape_poly_test on GPU
PiperOrigin-RevId: 723940915
2025-02-06 08:13:55 -08:00
Michael Hudgins
2e808f2836 Merge pull request #26279 from MichaelHudgins:tsan-resultstore
PiperOrigin-RevId: 723918760
2025-02-06 14:55:57 +00:00
Gunhyun Park
a8df383ccf Fix lax.ragged_all_to_all degenerate case
In a singleton group case, unlike regular all_to_all, the ragged op becomes a generic equivalent of DynamicUpdateSlice, except update size is not statically known. This operation can't be expressed with standard HLO instructions -- the backend will handle this case separately.

Added small improvement to error messages.

PiperOrigin-RevId: 721473063
2025-01-30 12:05:02 -08:00
Justin Fu
b01111d96c Add skeleton for a multi-pass source mapper for Jaxprs/HLO to jax.experimental.
PiperOrigin-RevId: 721119935
2025-01-29 15:01:43 -08:00
Gunhyun Park
809e1133c8 Add support for axis_name and axis_index_groups to lax.ragged_all_to_all
PiperOrigin-RevId: 720738861
2025-01-28 16:02:03 -08:00
Peter Hawkins
faaaf82974 Disable pytorch_interoperability_test under asan on all backends.
It wasn't sufficient to disable this only on GPU.

PiperOrigin-RevId: 720344366
2025-01-27 16:18:28 -08:00
Peter Hawkins
42fd586e79 Disable pytorch_interoperability_test under asan.
PiperOrigin-RevId: 720189636
2025-01-27 09:00:40 -08:00
George Necula
e4d5427d13 [better_errors] Add more debug info test coverage
Try to cover the tracing of almost all JAX higher-order
primitives. Some of the tests added show missing debug info,
marked with TODO. Fixes will come separately.

Had to expand the helper functions _check_tracers_and_jaxprs to
use regular expressions for matching because some debug info
still contains non-deterministic elements.
2025-01-26 08:12:29 +02:00
Dan Foreman-Mackey
e3b3b913f7 Add an experimental interface for customizing DCE behavior.
We use dead code elimination (DCE) throughout JAX core to remove unused computations from Jaxprs. This typically works transparently when we're just using `lax` primitives, but opaque calls to `pallas_call` or `ffi_call` can't be cleaned up this way. For many kernels however, the author will know how to generate a more efficient call for specific patterns of used outputs, so it is useful to provide a mechanism for customizing this behavior.

In https://github.com/jax-ml/jax/pull/22735, I attempted to automatically tackle one specific example of this that comes up frequently, but there have been feature requests for a more general API. This version is bare bones and probably rough around the edges, but it could be a useful starting point for iteration.

PiperOrigin-RevId: 718950828
2025-01-23 11:38:47 -08:00
George Necula
849ccc978b [better_errors] Expand the tests for debug_info
Debugging info is needed for error messages, and for
lowering. For the former, we need debug info inside
tracers. For the latter, inside Jaxprs. We add a
new set of tests that intentionally leak tracers while
tracing and then we check that the tracers have the
expected debug info. We also form Jaxprs and we
check that they have the expected debug info.
We uncovered a few missing debug infos, those are
marked with TODO.
2025-01-22 16:49:16 +01:00
George Necula
e5d89e738a [better_errors] Refactor debug info tests
Created debug_info_test.py and moved there some of the
tests involving debug_info. In the future we will put here
more tests for debugging info, and their helper functions.
2025-01-20 20:21:01 +01:00
Peter Hawkins
f122f17b27 Rename test configs to include GPU variants more consistently.
* Include "p100" or "v100" in the default "gpu" config names, matching their current CI configuration.
* Rename "_2gpu" test variants to "x2" variants, since this is more succinct.

This change is intended to be a pure renaming, and it is not intended to alter the set of tests that run.

PiperOrigin-RevId: 715468944
2025-01-14 11:55:45 -08:00
Bart Chrzaszcz
dc53c563bb #sdy enable pure callbacks and debug prints in JAX.
Everything passes other than an io callback test due to the lowered `sdy.manual_computation` returning a token. Will be fixed in a follow-up.

PiperOrigin-RevId: 713780181
2025-01-09 13:37:51 -08:00
Peter Hawkins
b06779b177 Switch to a new thread-safe utility for catching warnings.
The Python warnings.catch_warnings() functionality is not thread-safe (https://py-free-threading.github.io/porting/#the-warnings-module-is-not-thread-safe), so we cannot use it during tests that use free-threading. This change introduces a private warnings test helper (test_warning_util.py), which hooks the CPython warning infrastructure and uses it to implement thread-safe warnings infrastructure.

This requires a handful of small modifications to tests to remove direct uses of the warnings module. We also sadly have to delete one TPU test that checks for a warning raised on another thread; there's no easy way for us to catch that in a thread-safe way, but that test seems like overkill anyway.
2025-01-09 11:58:34 -05:00
jax authors
56f0f9534d Merge pull request #25633 from dfm:move-ffi
PiperOrigin-RevId: 712863350
2025-01-07 04:40:21 -08:00
Jake VanderPlas
c7b0d681bd Remove deprecated jax.experimental.array_api 2025-01-06 15:19:02 -08:00
Dan Foreman-Mackey
cb4d97aa1f Move jex.ffi to jax.ffi. 2024-12-29 13:06:19 +00:00
Berkin Ilbeyi
f17b2bc2d3 Reenable for_loop_test on TPU v5p.
PiperOrigin-RevId: 704298792
2024-12-09 08:38:41 -08:00
Bixia Zheng
2a4a0e8d6f [jax:custom_partitioning] Implement SdyShardingRule to support
Shardy custom_partitioning.

The parsing of the sharding rule string very closely follows how einops parses
their rules in einops/parsing.py.

When a SdyShardingRule object is constructed, we check the syntax of the Einsum
like notation string and its consistency with the user provided factor_sizes,
and report errors accordingly. This is done during f.def_partition.

When SdyShardingRule.build is called, during JAX to MLIR lowering, we check
the consistency between the Einsum like notation string, the factor_sizes
and the MLIR operation, and report errors accordingly.

PiperOrigin-RevId: 703187962
2024-12-05 11:33:23 -08:00
Hyeontaek Lim
e20a483bef [JAX] Add end-to-end execution support in colocated Python API
This change adds a capability to run colocated Python function calls through
`PyLoadedExecutable`. This capability is not yet used for McJAX, but is tested
with a prototype of a colocated Python backend. The overall behavior remains
the same for McJAX (running the user code inline when colocated Python is
called); the new logic will be used once we introduce a colocated Python
backend for McJAX.

Key highlights:

* Colocated Python is compiled into `PyLoadedExeutable` and uses the JAX C++
dispatch path.

* `CustomCallProgram` for a colocated Python compilation nows includes
specialization (input/output specs, devices). This information allows a
colocated Python backend to transform input/outputs and validate
PyTree/dtype/shape/sharding.

* `out_specs_fn` now receives `jax.ShapeDTypeStruct`s instead of concrete values.

* Deserialization of devices now prefers the default backend. This improves the
compatibility with an environment using both multi-platform backend as well as
the standard "cpu" backend at the same time.

* Several bugs have been fixed (e.g., correctly using `{}` for kwargs).

PiperOrigin-RevId: 703172997
2024-12-05 10:52:40 -08:00
Enrique Piqueras
8c521547b7
Add experimental JAX roofline API. 2024-11-27 14:38:57 -08:00
Hyeontaek Lim
bbaec6ea59 [JAX] Add Python binding for building a colocated Python program
This change adds a Python binding that makes `ifrt::CustomCallProgram` for a
colocated Python program. This Python binding will be used internally in the
colocated Python API implementation. The API does not yet compile the program
into an executable, which will be added separately.

PiperOrigin-RevId: 700443656
2024-11-26 13:31:15 -08:00
Bill Varcho
f22bafac31 [SDY] remove TODO for enabling Layouts for Shardy post cl/697715276.
PiperOrigin-RevId: 700053383
2024-11-25 11:45:00 -08:00
Bill Varcho
bb1024f3fd [SDY] enable cpu_shardy for JAX shard_alike test.
PiperOrigin-RevId: 700029576
2024-11-25 10:33:17 -08:00
Bill Varcho
0ed6eaeb4a [SDY] fix JAX layouts tests for Shardy.
PiperOrigin-RevId: 697715276
2024-11-18 12:14:32 -08:00
Dan Foreman-Mackey
ccb331707e Add a GPU implementation of lax.linalg.eig.
This feature has been in the queue for a long time (see https://github.com/jax-ml/jax/issues/1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (https://github.com/jax-ml/jax/issues/24255; this should be investigated separately).

This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.)

We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_use_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable.

PiperOrigin-RevId: 697631402
2024-11-18 08:11:57 -08:00
jax authors
a1eb5ceade Merge pull request #23374 from jaro-sevcik:mock-topology-config
PiperOrigin-RevId: 696540499
2024-11-14 08:55:04 -08:00
jax authors
12c8c68c4a Merge pull request #24069 from sergachev:cudnn_fusion_test_a100
PiperOrigin-RevId: 696200281
2024-11-13 11:06:08 -08:00
Jaroslav Sevcik
eedd01118b Add an option to specify mock GPU topology 2024-11-12 08:36:27 -08:00
Peter Hawkins
7491fdd94c Disable for_loop_test on TPU v5p.
This test is failing in CI.

PiperOrigin-RevId: 695278007
2024-11-11 04:09:44 -08:00
Peter Hawkins
7285f10e84 Disable lax_test on ARM in Google's internal CI.
There are numerical errors from the complex plane function tests.

PiperOrigin-RevId: 694579368
2024-11-08 11:33:19 -08:00
Bill Varcho
afd8239ea4 [SDY] add JAX lowering to Shardy ShardingGroupOp for shard_alike.
PiperOrigin-RevId: 694567084
2024-11-08 11:02:50 -08:00
Peter Hawkins
3b2e4a1600 Remove sharding from custom_root_test.
This test only takes around 30s on most hardware platforms, it does not need 10 shards.

PiperOrigin-RevId: 694243316
2024-11-07 14:12:21 -08:00
Peter Hawkins
ea1e879577 Include mpmath as a bazel dependency of lax_test.
This test has additional test cases that require mpmath.

PiperOrigin-RevId: 693464078
2024-11-05 13:43:06 -08:00
Ilia Sergachev
e083c08001 Re-enable cudnn_fusion_test on A100.
Check that the required cuDNN version is available.
2024-11-01 15:48:07 +00:00
Bart Chrzaszcz
44158ab0e4 #sdy add shardy CPU config for all JAX tests, disabling any known failing test cases.
Only test cases breaking on CPU are related to:
- pure callbacks
- export
- shard alike

Note that `layout_test` is broken on TPU, leaving a comment saying to enable it.

Also fixed `shard_map_test` test that was broken when running Shardy on one TPU, and `aot_test` which was breaking due to calling a different C++ StableHLO compilation function.

PiperOrigin-RevId: 691496997
2024-10-30 11:40:20 -07:00
Jake VanderPlas
e61a20b45a Remove deprecated jax.experimental.export module.
These tools are now available at jax.export.
2024-10-30 05:27:29 -07:00
Yash Katariya
e35e7f8e20 Allow sparsecore compute with T(8) layout via the layout API and compute_on API. To annotate compute on sparsecore, use @compute_on('tpu_sparsecore').
PiperOrigin-RevId: 691225280
2024-10-29 17:58:53 -07:00
Hyeontaek Lim
77797f434d [JAX] Add the function API of jax.experimental.colocated_python
This change adds an experimental API `jax.experimental.colocated_python`. The
ultimate goal of this API is to provide a runtime-agnostic way to wrap a Python
code that runs close to (or on) accelerator hosts. Multi-controller JAX can
trivially achieve this colocated Python code execution today, while
single-controller JAX needed its own solution for distributed Python code
execution, which creates fragmentation of the user code for these two runtime
architectures. `colocated_python` is an attempt to define a single device model
and portable API to allow the user to write a single code once that can run on
both runtime architectures.

This change includes an implementation of the function API portion of
`jax.experimental.colocated_python`. A (stateful) object API will be added
separately. Also there will be a separate change that expresses serialized
functions as an IFRT `CustomCallProgram`.

It is currently in an early development stage. Please proceed with a caution
when using the API.

PiperOrigin-RevId: 690705899
2024-10-28 12:18:48 -07:00
jax authors
1336c2d5c4 Fix breaking PGLE test-cases
PiperOrigin-RevId: 690608075
2024-10-28 07:50:31 -07:00
Ionel Gog
ec279f9c54 Add config option to log or fatal when jax.Arrays are GCed.
Introduces `jax.config.array_garbage_collection_guard`, which is a tristate config for setting up a `jax.Array` garbage collection guard. The possible configs are:
* allow: `jax.Array`s are allowed to be garbage collected. This is the default value.
* log: whenever a `jax.Array` is GCed a log entry is generated with the array's traceback.
* fatal: fatal crash when a `jax.Array` is GCed. This is meant to be used for mature code bases that do tight memory management, and are reference cycle free.

PiperOrigin-RevId: 687003464
2024-10-17 12:23:16 -07:00
Bart Chrzaszcz
fb32841b1b #sdy add JAX Shardy support for memories.
PiperOrigin-RevId: 684867097
2024-10-11 09:44:24 -07:00
Peter Hawkins
66f526894f Reenable some test cases that were disabled due to bugs that now seem fixed.
PiperOrigin-RevId: 684464642
2024-10-10 09:06:06 -07:00
Peter Hawkins
19dbff5326 Move additional CI enabled/disabled configurations into jax BUILD files.
PiperOrigin-RevId: 684457403
2024-10-10 08:41:45 -07:00
George Necula
023f2a78be Remove remaining implementations of jax.experimental.host_callback.call.
The host_callback module has been deprecated since March 2024, and we are now removing the implementation. We keep the functions so that we can give a nicer error message than AttributeError, and because removing those now break internal pytype checking. We will remove those in the near future.

See https://github.com/google/jax/issues/20385.

PiperOrigin-RevId: 683564340
2024-10-08 04:22:20 -07:00
Adam Paszke
7102c7adbf Bump the shard_count of FFT tests to avoid timeouts
PiperOrigin-RevId: 683537643
2024-10-08 02:44:41 -07:00
George Necula
db89c245ac [host_callback] Remove most of the jax.experimental.host_callback module
These APIs have been deprecated since March 2024 and they are subsumed by the new JAX external callbacks.
See https://github.com/google/jax/issues/20385 for a discussion.

PiperOrigin-RevId: 682830525
2024-10-06 01:10:34 -07:00
George Necula
b8a066a907 [host_callback] Remove obsolete tests.
Removing tests that only work in legacy mode and with outfeed.

PiperOrigin-RevId: 681435113
2024-10-02 06:51:02 -07:00