270 Commits

Author SHA1 Message Date
Adam Paszke
d2f937e241 Make jax.Arrays a necessary part of the cycle in the GC guard test
Otherwise, the cycle can be broken by clearing the references of the helper
objects, at which point the deallocation of arrays proceeds through regular
reference counting (and does not trigger logs!). I have not verified that
this is what happens, but the test has been mysteriously failing under a
number of configurations and this seems to fix it.

I added a note to the garbage collection guard to clarify that it's not
guaranteed to report all cycles.

PiperOrigin-RevId: 708320953
2024-12-20 07:48:04 -08:00
Yash Katariya
8b734808e8 Remove jax_enable_memories config flag. It defaulted to True for a very long time and it's time to remove the flag.
PiperOrigin-RevId: 707590263
2024-12-18 10:15:45 -08:00
Matthew Johnson
42ac4ca357 ref errors 2024-12-18 07:46:14 +00:00
Yash Katariya
1e22149493 Fix the breakage caused by deleted enable_memories config
PiperOrigin-RevId: 707331603
2024-12-17 18:17:13 -08:00
Yash Katariya
cca9afa28f Delete enable_memories code in C++ since that flag is always True and cannot be turned off now.
PiperOrigin-RevId: 707298305
2024-12-17 16:43:20 -08:00
jax authors
a123d4e39e Remove autotune sharing.
xla_gpu_shard_autotuning can be used now instead and it is enabled by default.

PiperOrigin-RevId: 705792463
2024-12-13 01:22:27 -08:00
Peter Hawkins
79318a08cf Remove dead code after minimum jaxlib version bump to v0.4.36.
New minimum xla_extension_version is 299, and the new mlir_api_version is 57.

PiperOrigin-RevId: 704280856
2024-12-09 07:35:05 -08:00
jax authors
182e532675 Merge pull request #25114 from jedborovik:add-optimization-effort-flags
PiperOrigin-RevId: 702892538
2024-12-04 16:04:16 -08:00
Yash Katariya
a735bf83e5 Simplify abstract_mesh and device_context context managers and handle everything via their corresponding configs in config.py
PiperOrigin-RevId: 702852769
2024-12-04 14:04:25 -08:00
Jed Borovik
c65ce4b093
Merge branch 'main' into add-optimization-effort-flags 2024-11-27 14:08:10 -05:00
Yash Katariya
0d2dfea4b1 Add a private set_mesh API to enter into sharding_in_types mode. This is how users will enable sharding in types mode (with correct axis types set too but that doesn't work yet).
Also adding a device_context so `set_mesh` sets the devices the computation should run on correctly. The device_context, however, enters concrete devices into the tracing and lowering cache; this should be fixed with the other jax context work going on.

PiperOrigin-RevId: 700537898
2024-11-26 20:01:04 -08:00
labs-code-app[bot]
762301fc5d Add exec_time_optimization_effort and memory_fitting_effort flags.
These flags control the amount of effort the compiler spends minimizing execution time and memory usage, respectively. They can be set via the command line. Valid values are between -1.0 and 1.0; the default is 0.0.
2024-11-26 13:57:47 +00:00
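A minimal sketch of how these flags might be set from Python rather than the command line (the `jax_`-prefixed config names are assumptions derived from the flag names in the message, not verified API):

```python
import jax

# Assumed config names mirroring the commit's flag names; values range
# from -1.0 (least effort) to 1.0 (most effort), with 0.0 the default.
jax.config.update("jax_exec_time_optimization_effort", 0.5)   # more effort on runtime
jax.config.update("jax_memory_fitting_effort", -0.5)          # less effort on memory fitting
```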
Yash Katariya
c35f8b22c1 Add abstract mesh context manager to trace_context in the fallback path too (which will be deleted after jax 0.4.36 release)
PiperOrigin-RevId: 700006186
2024-11-25 09:18:30 -08:00
Yash Katariya
40fc6598f9 [sharding_in_types] Make flash_attention forward pass in TPU pallas work nicely with sharding in types. Backward pass is still busted, which I will fix in follow-up CLs.
Set the abstract mesh context manager at the jit tracing boundary by looking at the mesh on the avals. In the future, this context manager will be user settable too.

Abstract mesh context manager is a new context manager with a new context variable and new trace_context entry which governs the cache behavior. If the abstract mesh context manager is not set, the default is `None`.

PiperOrigin-RevId: 698493184
2024-11-20 13:07:30 -08:00
Dougal
d0f17c0c04 Make a direct linearize trace.
This is an alternative to doing JVP followed by partial eval. The linearize
trace has two parent traces, one for the primal computation and one for the
tangent computation. If we make the tangent trace a DynamicJaxprTrace then we
get staged linearization. If we make it the same as the primal trace then we get
primal and tangent computations occurring in lockstep (JVP). This is a neat trick
enabled by stackless which now lives up to its name. With two parent traces we
have a tree of traces not a linked list stack.

Primitive ops can have their own linearization rules but as a fallback we can
derive a linearization rule for a single op using jvp/partial-eval.

For now this is all under a flag, `use_direct_linearize`, but I'm hoping we can
make this the default for linearize/grad. It should help with remat and AD
through state which are awkward to express via partial eval.
2024-11-20 10:03:00 -08:00
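As a hedged sketch, the flag described above might be enabled like this (the `jax_use_direct_linearize` spelling is an assumption based on the `use_direct_linearize` name in the message):

```python
import jax

# Assumed config name; switches jax.linearize/grad to the direct
# linearize trace instead of JVP followed by partial eval.
jax.config.update("jax_use_direct_linearize", True)

# jax.linearize itself is public API: it returns the primal output and
# a linear function applying the JVP at the linearization point, e.g.
#   y, f_lin = jax.linearize(lambda x: x * x, 3.0)
</imports>
```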
Dan Foreman-Mackey
ccb331707e Add a GPU implementation of lax.linalg.eig.
This feature has been in the queue for a long time (see https://github.com/jax-ml/jax/issues/1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (https://github.com/jax-ml/jax/issues/24255; this should be investigated separately).

This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.)

We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_use_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable.

PiperOrigin-RevId: 697631402
2024-11-18 08:11:57 -08:00
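A configuration sketch based on the description above (the config name and env var are as described in the commit; the install path is purely illustrative):

```python
import os

# Point JAX at a non-standard MAGMA install before importing jax.
os.environ["JAX_GPU_MAGMA_PATH"] = "/opt/magma/lib/libmagma.so"

import jax

# Opt in to the MAGMA-backed eig on GPU; by default JAX calls LAPACK on
# host, which wins for matrices smaller than roughly 2048x2048.
jax.config.update("jax_gpu_use_magma", "on")
```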
jax authors
cea8176756 Merge pull request #24751 from Stella-S-Yan:feature/default_device_str
PiperOrigin-RevId: 696560063
2024-11-14 10:00:18 -08:00
Trevor Morris
a79d307ac7 When caching is enabled, also enable XLA caching features as well
Add unit test

Fix typechecker

Set caching mode depending on process id
2024-11-13 10:30:04 -08:00
Stella-S-Yan
afa518aa0e Allow setting default_device with platform names. 2024-11-11 22:46:57 +00:00
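Per this commit, a platform name string should now work where a concrete `Device` object was previously required; a sketch:

```python
import jax

# Previously this needed a Device, e.g. jax.devices("cpu")[0]; per the
# commit, a platform name string like "cpu" is now accepted as well.
jax.config.update("jax_default_device", "cpu")
```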
Dan Foreman-Mackey
4a365670f7 Fix pre-commit to run on all files in CI. 2024-11-08 13:47:27 -05:00
Robert Dyro
04f2ef9e93 Adding JAX_LOGGING_LEVEL configuration option 2024-11-05 09:56:46 -08:00
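A usage sketch for this option (the requirement to set it before importing jax is an assumption about when the logging configuration is read):

```python
import os

# Assumed usage: set the env var before jax is imported so the logging
# configuration is picked up at import time.
os.environ["JAX_LOGGING_LEVEL"] = "DEBUG"

import jax  # noqa: E402  (import after env setup is intentional)
```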
Peter Hawkins
0e8acff5c6 Reverts a913fbf2fddc5b8c1b6c85b159d0eeb1bf65d461
PiperOrigin-RevId: 693360032
2024-11-05 08:32:25 -08:00
jax authors
a913fbf2fd rollback due to data race
Reverts ab47d4687f647de3aa145a9a782fb7b4aaf92af4

PiperOrigin-RevId: 693191298
2024-11-04 21:05:33 -08:00
Peter Hawkins
ab47d4687f [JAX] [XLA:Python] Move JAX configuration objects into C++.
A noticeable amount of time during JAX tracing is spent getting and setting the value of config.State objects, in particular the thread-local values within that state. If we move that logic into C++, we can speed up that code.

There are two main ways we can get a speedup:
* Python thread-local state is based around a dictionary and isn't terribly fast.
* we can have the C++ jit dispatch path directly access the configuration items it needs to include in its cache key. We spend a considerable amount of time in effect eagerly computing cache keys via update_thread_local_jit_state, although most of that is pointless work. Instead, we can have `jit` simply pull the config items it needs on demand.

PiperOrigin-RevId: 693114411
2024-11-04 15:39:06 -08:00
jax authors
f5656bcb11 Merge pull request #24510 from dfm:dot-algorithm-config
PiperOrigin-RevId: 691096482
2024-10-29 11:30:38 -07:00
Dougal Maclaurin
c36e1f7c1a Make trace dispatch purely a function of context rather than a function of both context and data. This lets us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind and so on.
PiperOrigin-RevId: 691086496
2024-10-29 11:04:31 -07:00
Dan Foreman-Mackey
03854cfce4 Allow dot algorithms in default_matmul_precision config. 2024-10-29 10:48:21 -04:00
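A sketch of what this change enables (the specific algorithm name below is an assumption for illustration):

```python
import jax

# Per the commit, a dot-algorithm name (rather than only the older
# "highest"/"float32"-style precision strings) is now accepted here.
jax.config.update("jax_default_matmul_precision", "BF16_BF16_F32")
```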
Yash Katariya
6c8e56f43f Finish 0.4.35 release by removing dead code
PiperOrigin-RevId: 689396609
2024-10-24 08:45:43 -07:00
Ionel Gog
ec279f9c54 Add config option to log or fatal when jax.Arrays are GCed.
Introduces `jax.config.array_garbage_collection_guard`, which is a tristate config for setting up a `jax.Array` garbage collection guard. The possible configs are:
* allow: `jax.Array`s are allowed to be garbage collected. This is the default value.
* log: whenever a `jax.Array` is GCed a log entry is generated with the array's traceback.
* fatal: fatal crash when a `jax.Array` is GCed. This is meant to be used for mature code bases that do tight memory management, and are reference cycle free.

PiperOrigin-RevId: 687003464
2024-10-17 12:23:16 -07:00
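A configuration sketch for the guard described above (the `jax_array_garbage_collection_guard` config name is an assumption derived from `jax.config.array_garbage_collection_guard` in the message):

```python
import jax

# Tristate: "allow" (default), "log" (log a traceback when a jax.Array
# is garbage collected), or "fatal" (crash, for cycle-free code bases).
jax.config.update("jax_array_garbage_collection_guard", "log")
```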
Dan Foreman-Mackey
19313f4c0f Fix lint at HEAD. 2024-10-02 16:18:03 -04:00
jax authors
cfb75411c8 Merge pull request #23738 from keshavb96:disable_remat_pass
PiperOrigin-RevId: 681510015
2024-10-02 10:43:20 -07:00
Keshav
8770fb283b set default value to True 2024-09-20 11:48:41 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Loren Maggiore
f75c5c6b2d [jax] config option to disable using a mesh as a context manager.
PiperOrigin-RevId: 676475039
2024-09-19 10:42:41 -07:00
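A hedged sketch of the option described above (the config name here is a guess; the commit does not state it):

```python
import jax

# Hypothetical config name for the option this commit adds: when
# enabled, `with mesh:` raises instead of installing the mesh globally.
jax.config.update("jax_disallow_mesh_context_manager", True)
```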
Keshav
caf57495cf use bool_state instead of bool_flag 2024-09-18 14:53:45 -07:00
Parker Schuh
bf2237a102 Flip jax_pmap_no_rank_reduction by default to True.
This changes:
* The performance of array[0] (use array[0:1] instead).
* The shape of jax_array.addressable_shards or jax_array.addressable_data(0) of arrays that come from pmap.

PiperOrigin-RevId: 673564995
2024-09-11 15:41:47 -07:00
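The indexing change called out above is ordinary NumPy-style rank reduction versus slicing; a small illustration (plain NumPy standing in for `jax.Array` semantics):

```python
import numpy as np

x = np.zeros((4, 3))

# array[0] drops the leading axis (rank reduction)...
assert x[0].shape == (3,)

# ...while array[0:1] keeps it, which is the cheap form once
# jax_pmap_no_rank_reduction is flipped on.
assert x[0:1].shape == (1, 3)
```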
Keshav
7c660c4ea0 Squashed commit of the following:
commit 1abe9559d1ba7a6ec4e2081c52ebdf0eef6b5e56
Merge: 1e1cc3e07 1b2ba9d1c
Author: Keshav <keshavb@nvidia.com>
Date:   Tue Sep 10 09:42:04 2024 -0700

    Merge remote-tracking branch 'upstream/main' into rm_custom_partitioning_pointer

commit 1e1cc3e0733cca77e2f1ee928f96edcf63f673cf
Author: Keshav <keshavb@nvidia.com>
Date:   Tue Sep 10 09:37:22 2024 -0700

    added comment

commit 631c41fcbdbbac864fadd72c984b07801872f218
Merge: b93b52f27 ce3ea109a
Author: Keshav <keshavb@nvidia.com>
Date:   Wed Aug 21 08:54:00 2024 -0700

    Merge remote-tracking branch 'upstream/main' into rm_custom_partitioning_pointer

commit b93b52f27aacf7f58eba914a91810b5d0ac06316
Author: Keshav <keshavb@nvidia.com>
Date:   Tue Aug 20 19:00:08 2024 -0700

    remove stray breakpoint

commit 9ee0842ea98557bcdca0ecfd9031a8ea5274e9a4
Merge: 799e359a5 be53ee10b
Author: Keshav <keshavb@nvidia.com>
Date:   Wed Aug 7 18:09:19 2024 -0700

    Merge remote-tracking branch 'upstream/main' into rm_custom_partitioning_pointer

commit 799e359a522acd1a83dd7868a3a9278e189664f6
Author: Keshav <keshavb@nvidia.com>
Date:   Wed Aug 7 17:31:27 2024 -0700

    added tests and minor changes

    fix

commit c973004493f633526b14a6b5acb3afe50d58c977
Merge: 5900969cc b3924da2a
Author: Keshav <keshavb@nvidia.com>
Date:   Thu Aug 1 11:28:59 2024 -0700

    Merge remote-tracking branch 'upstream/main' into rm_custom_partitioning_pointer

commit 5900969cc9178bf3629baa49c6a300446bf6d4a9
Author: Keshav <keshavb@nvidia.com>
Date:   Thu Aug 1 11:20:52 2024 -0700

    minor edits

commit a7cc85a1cb8ddd07b783cc538f25c56f5fb78543
Merge: 89b876270 091eba195
Author: Keshav <keshavb@nvidia.com>
Date:   Mon Jul 29 14:17:13 2024 -0700

    Merge remote-tracking branch 'upstream/main' into rm_custom_partitioning_pointer

commit 89b876270bf5f16dc10c2f8700d69715752ca184
Author: Keshav <keshavb@nvidia.com>
Date:   Mon Jul 29 14:11:39 2024 -0700

    native IR traversal instead of string manipulation

commit 3b161a414d9579c50e1902047dbd45bac840a767
Author: Keshav <keshavb@nvidia.com>
Date:   Sun Jul 28 20:12:30 2024 -0700

    longer match string and string search optimization

commit 224ee59d2115ec43000105b97bd6e73c40777ab9
Merge: c7664aa61 6a7822a73
Author: Keshav <keshavb@nvidia.com>
Date:   Sun Jul 28 17:08:29 2024 -0700

    Merge remote-tracking branch 'upstream/main' into rm_custom_partitioning_pointer

commit c7664aa61fa9cec55fba9d5ee1d3ffb146a4c2b1
Author: Keshav <keshavb@nvidia.com>
Date:   Sun Jul 28 17:07:04 2024 -0700

    remove custom partitioning ptr from pre-compiled hlo during cache key computation

linter fixes

more linter fixes

more linter fixes

alternate imports
2024-09-10 17:30:08 -07:00
jax authors
02b7a76768 Add frontend attributes to Jax. This allows Jax users to annotate Jax code with frontend_attributes which can be traced down to the HLO level, to be used for numerical debugging purposes.
PiperOrigin-RevId: 671930431
2024-09-06 16:44:56 -07:00
Yash Katariya
a144eb234b Add compute_on_context_manager to thread local jit state. This is to avoid getting false cache hits.
PiperOrigin-RevId: 671507042
2024-09-05 14:16:13 -07:00
Yash Katariya
f1e0741890 Add use_shardy_partitioner to thread local jit state
PiperOrigin-RevId: 670230769
2024-09-02 08:47:06 -07:00
Yash Katariya
bcfe95e98e Initial integration of sharding in types in JAX. Currently we support only n-ary ops in forward-only sharding propagation. This functionality is experimental and hidden behind the jax_sharding_in_types config flag.
There will be more improvements and semantics clarification coming in the future as we integrate it more into JAX.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
PiperOrigin-RevId: 668991384
2024-08-29 10:50:04 -07:00
Matthew Johnson
670a648b7b add experimental jax.no_tracing context manager 2024-08-23 21:21:55 +00:00
Yash Katariya
be53ee10b1 Set jax_enable_memories flag to True by default
PiperOrigin-RevId: 660579462
2024-08-07 16:25:25 -07:00
Matthew Johnson
abfb1ce72d add temporary flag to suppress an error message, to unblock a user 2024-07-31 17:23:47 +00:00
Bart Chrzaszcz
864178d3a3 #sdy Initial set of changes to allow for lowering to the Shardy dialect.
The OpenXLA project is working on an open source, MLIR, named-axis based propagation (and in the future SPMD partitioning) system that will be dialect agnostic (would work for any dialect - MHLO, StableHLO, YourDialect). We plan on having frontends like JAX and PyTorch target this when using XLA and wanting SPMD propagation/partitioning. See www.github.com/openxla/shardy for more info.

Currently Shardy is implemented inside the XLA compiler, requiring us to round-trip between StableHLO and HLO with `mhlo.sharding`s. But we will eventually make Shardy the first pass in the XLA pipeline while it's still working on StableHLO. Partitioning (the system that adds the collectives like all-gathers/all-reduces) will still be the GSPMD Partitioner, but next year the Shardy partitioner will be developed, allowing for propagation and partitioning to be completely in MLIR and the first pass in the pipeline. So then we'd have:
1. Traced jaxpr
2. Jaxpr -> StableHLO
3. StableHLO with Shardy propagation
4. StableHLO with Shardy partitioning
5. StableHLO -> HLO
6. XLA optimizations

The following test:

```py
def test_sdy_lowering(self):
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  np_inp = np.arange(16).reshape(8, 2)
  s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
  arr = jax.device_put(np_inp, s)

  @partial(jax.jit, out_shardings=s)
  def f(x):
    return x * 2

  print(f.lower(arr).as_text())
```

outputs:

```
module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <"x"=4, "y"=2>
  func.func public @main(%arg0: tensor<8x2xi64> {mhlo.layout_mode = "{1,0}", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi64> {jax.result_info = "", mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
    %c = stablehlo.constant dense<2> : tensor<i64>
    %0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i64>) -> tensor<8x2xi64>
    %1 = stablehlo.multiply %arg0, %0 : tensor<8x2xi64>
    return %1 : tensor<8x2xi64>
  }
}
```

Shardy will be hidden behind the `jax_use_shardy_partitioner` flag initially before becoming enabled by default in the future.

PiperOrigin-RevId: 655127611
2024-07-23 05:32:06 -07:00
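The flag named at the end of the message above can be flipped on like this (a config sketch; `jax_use_shardy_partitioner` is the flag name the commit itself gives):

```python
import jax

# Opt in to Shardy lowering before it becomes the default.
jax.config.update("jax_use_shardy_partitioner", True)
```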
George Necula
d34a6e9ce2 [jax2tf] Deprecate jax2tf with native_serialization=False or enable_xla=False.
Also disable many of the non-native-serialization jax2tf tests.
In particular, I am disabling the thousands of primitives tests in
graph serialization mode.
I kept jax2tf_test running in both native and graph serialization mode.

PiperOrigin-RevId: 652749891
2024-07-16 02:05:43 -07:00
George Necula
cbe524298c Ported threefry2x32 for GPU to the typed XLA FFI
This allows lowering of threefry2x32 for GPU even on a machine without GPUs.

For the next 3 weeks, we only use the new custom call implementation if
we are not in "export" mode, and if we use a new jaxlib.

PiperOrigin-RevId: 647657084
2024-06-28 06:24:44 -07:00
George Necula
47f1b3de2c [export] Add documentation for debugging and for ensuring compatibility.
The rendered documentation is at https://jax--21976.org.readthedocs.build/en/21976/export/export.html#developer-documentation (for the export developer documentation, including compatibility) and https://jax--21976.org.readthedocs.build/en/21976/export/shape_poly.html#debugging (for the shape polymorphism debugging documentation)

While testing the compatibility mechanism I discovered that it can be circumvented by caches.
To fix this, I added export_ignore_forward_compatibility to mlir.LoweringParameters.
2024-06-28 08:36:55 +03:00
jax authors
639892bb04 Merge pull request #22123 from mattjj:dynamic-trace-state-simplification
PiperOrigin-RevId: 647129667
2024-06-26 17:21:05 -07:00
Matthew Johnson
275ddad51d tweak dynamic trace state to only depend on level int, not MainTrace 2024-06-26 23:42:49 +00:00