302 Commits

Author SHA1 Message Date
jax authors
0ae613ee48 Makes Effort_02 the default value for memory_fitting_level.
PiperOrigin-RevId: 749159983
2025-04-18 15:09:02 -07:00
Peter Hawkins
96865709b1 Allow the CPU collective implementation to be overridden to None.
PiperOrigin-RevId: 749055960
2025-04-18 09:29:11 -07:00
Parker Schuh
7634230cdc Remove unused jax_spmd_mode flag.
PiperOrigin-RevId: 748792684
2025-04-17 13:32:52 -07:00
Yash Katariya
82215f660e Remove jax_varying_axes_in_types config and rewrite from shard_map_p
PiperOrigin-RevId: 748545142
2025-04-16 22:27:50 -07:00
jax authors
16ffbca542 Merge pull request #27849 from ZacCranko:docfig
PiperOrigin-RevId: 746098316
2025-04-10 11:06:37 -07:00
Zac Cranko
8f9f1aa35a add sphinx extension and placeholder config docs rst
improve layout, information

add dummy import to hopefully fix build issue

parse help text for markdown

whoops didn't mean to do it twice

jax prefix text no longer applies here

two space indents

address definition list ending without blank line error

provide deprecation mechanism

document context managers if they exist

remove mention of context manager

try and fix formatting

improve formatting, fail to fix warnings

fail to fix bug, make better looking anyway

okay bug was in the parsing of help text to rst, some of which does not parse

wow, found the bug, turns out help strings were not valid rst
2025-04-10 05:55:10 -07:00
Yash Katariya
75e4279e32 Set jax_varying_axes_in_types to True by default.
PiperOrigin-RevId: 745739477
2025-04-09 14:40:31 -07:00
Yash Katariya
8301c304c1 Make changes to shard_map to prepare for setting varying_axes_in_types to True.
The main changes here are:

* Don't take the `_efficient_transpose_rewrite` transformation path anymore. In other words, `RewriteTrace` and all the rewriting machinery is dead.

* Wherever internally we were setting `check_rep=False` explicitly like `_prim_applier`, `_match`, `_unmatch`, `_shard_map_partial_eval`, `_shard_map_partial_eval_custom` (for remat), don't do that anymore. Instead set `check_rep` to the `check_rep` value so that it can be True if the user hasn't passed `check_rep=False`.

* Introduce an internal `_check_rep` context manager and set it wherever `extend_axis_env_nd` is used so that if `check_rep=False` on `shard_map`, JAX will set `vma` in `ShapedArray` to empty `frozenset`.

* Because of point (2), if `check_rep=True`, we can't set `in_specs` and `out_specs` of shmap internally to all manual axes of the mesh on the 0th dim. It needs to be whatever the argument was varying on.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 745276474
2025-04-08 13:47:13 -07:00
Peter Hawkins
e02faabfb2 Replace references to jax.readthedocs.io with docs.jax.dev.
PiperOrigin-RevId: 745156931
2025-04-08 08:33:49 -07:00
Sergei Lebedev
12811f08a8 Removed eager_pmap config option
It defaults to True and is not flipped to False by any internal JAX users.

PiperOrigin-RevId: 745067361
2025-04-08 03:30:36 -07:00
Sergei Lebedev
2944e3b2a6 Removed data_dependent_tracing_fallback config option
No internal code needs it any more.

PiperOrigin-RevId: 744870756
2025-04-07 15:27:57 -07:00
Sergei Lebedev
ff00fa91ce Removed unused jax_remat_opt_barrier config option
It defaults to True and is not flipped to False by any internal JAX users.

PiperOrigin-RevId: 744754343
2025-04-07 09:48:57 -07:00
George Necula
1941714d26 [export] Add support for override_lowering_rules to jax.export.
This parameter is already part of the internal API for the
AOT lowering function, here we just expose it to `jax.export`.
2025-04-03 16:13:16 +01:00
Ayaka
7dd78d97fa Add support for configurable error checking categories
PiperOrigin-RevId: 739234594
2025-03-21 10:53:34 -07:00
Ilya Tikhonovskiy
c9ac82c826 [XLA:GPU] Add missing BF16_BF16_F32_X9 matmul option in config.py
Extend the list of possible default algorithms that dot could use.

PiperOrigin-RevId: 736879149
2025-03-14 08:58:59 -07:00
Yash Katariya
e1b62cede1 Raise an error if jax.config.update('jax_num_cpu_devices', val) is called after backend is initialized
PiperOrigin-RevId: 736646012
2025-03-13 14:45:53 -07:00
Yash Katariya
abcc7fdf4c [sharding_in_types] Initial commit to add varying_manual_axes: frozenset[AxisName] to ShapedArray. Also add a jax_varying_axes_in_types config to keep this option hidden while we develop it.
PiperOrigin-RevId: 736141670
2025-03-12 08:29:16 -07:00
Parker Schuh
b8b690e594 Add use_high_dynamic_range_gumbel flag which allows sampling gumbel such
that it more closely matches the CDF for low probability events (less than
2**-nmant).

Because -log(-log(x)) is more sensitive close to 1 than 0, we must use
-log(-log1p(-x)) instead to make better use of the extra range around 0.

PiperOrigin-RevId: 732757388
2025-03-02 19:42:40 -08:00
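The numerical point made in b8b690e594 can be illustrated with a minimal sketch that uses only Python's standard `math` module. This is not the JAX implementation, and the function names `gumbel_naive` and `gumbel_high_range` are hypothetical. It only shows why the `log1p` form keeps resolution for uniform samples very close to 0, where floating-point numbers are much denser than they are near 1.

```python
import math


def gumbel_naive(u: float) -> float:
    """Gumbel sample from a uniform u in (0, 1): -log(-log(u))."""
    return -math.log(-math.log(u))


def gumbel_high_range(x: float) -> float:
    """Same transform written in terms of x = 1 - u: -log(-log1p(-x)).

    Tiny x values are representable, whereas 1 - x would round to
    exactly 1.0 and the naive form could not distinguish them.
    """
    return -math.log(-math.log1p(-x))


x = 1e-20                            # far below float64 resolution around 1.0
rounds_to_one = (1.0 - x) == 1.0     # the naive form would only ever see u == 1.0
large_sample = gumbel_high_range(x)  # roughly 20 * ln(10)
```

For moderate inputs the two forms agree, since `log1p(-x)` equals `log(1 - x)`. They differ only in how much of the upper tail they can reach.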
Emily Fertig
82124da5cd Redefine is_fully_addressable in shardings to support zero local devices for McJAX.
PiperOrigin-RevId: 731526750
2025-02-26 18:17:35 -08:00
Peter Hawkins
256e37af5f Port many uses of contextlib.contextdecorator to explicit context manager classes.
contextdecorator turns out to be slower than just writing a decorator class explicitly. Since we use many decorators per-equation, this causes a measurable speed difference in certain benchmarks.

PiperOrigin-RevId: 730939406
2025-02-25 10:31:05 -08:00
Yash Katariya
a3edfb43ef Now that sharding_in_types config flag is True, remove the config and all the conditionals
PiperOrigin-RevId: 728653433
2025-02-19 06:53:35 -08:00
Yash Katariya
00d8297071 [sharding_in_types] Set the sharding_in_types config to True. This is a purely internal change and shouldn't affect any public APIs.
A caveat of enabling sharding-in-types by default is that we'll see tracing cache misses, which lead to lowering cache misses and compilation cache misses, in the **following cases** (the persistent compilation cache is not affected, so we'll still see a cache hit there):

1. Call `jitted_f(arr_ns)` with an array on `NamedSharding` and again `jitted_f(arr_ps)` with an array of same shape and dtype but now with `PositionalSharding`
    * This leads to a tracing cache miss because on the second call, the aval has no sharding since it's PositionalSharding. This applies to calling with any sharding other than NamedSharding

2. `jitted_f = jit(f, in_shardings=ns)`. Call `jitted_f(sharded_arr)` and then on the second call you pass a numpy array `jitted_f(numpy_arr)`
   * This also leads to a cache miss because the avals currently don't look at in_shardings because the semantics of in_shardings is complicated and I don't think we should change the aval based on in_shardings.

**The solution in both cases is to make sure to pass the array sharded on the same mesh during both calls to jit.**

PiperOrigin-RevId: 728361493
2025-02-18 14:35:14 -08:00
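The cache-miss reasoning in 00d8297071 can be pictured with a toy model. This is a sketch under loud assumptions: `ToyAval`, `trace_cache` and `toy_trace` are hypothetical stand-ins and do not reflect how JAX stores avals or its tracing cache. The only idea carried over from the commit is that the abstract value records a sharding when the input uses a `NamedSharding` and records none otherwise, so the two calls produce different cache keys.

```python
from dataclasses import dataclass
from typing import Dict, Optional, Tuple


@dataclass(frozen=True)
class ToyAval:
    """Toy abstract value: shape, dtype, and an optional sharding spec."""
    shape: Tuple[int, ...]
    dtype: str
    sharding: Optional[str]  # None models "no sharding on the aval"


trace_cache: Dict[ToyAval, str] = {}  # hypothetical tracing cache


def toy_trace(aval: ToyAval) -> bool:
    """Return True on a cache hit, False on a miss (and fill the cache)."""
    hit = aval in trace_cache
    if not hit:
        trace_cache[aval] = f"trace#{len(trace_cache)}"
    return hit


named = ToyAval((8,), "float32", "P('x')")  # input carrying a named sharding
unsharded = ToyAval((8,), "float32", None)  # same shape and dtype, no sharding

first_call = toy_trace(named)       # miss: nothing cached yet
second_call = toy_trace(named)      # hit: identical key
other_input = toy_trace(unsharded)  # miss: the sharding field differs
```

In this model, passing inputs that are sharded the same way on every call keeps the key stable, which matches the remedy stated in the commit message.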
jax authors
d3850e7fdd Support optimization_level and memory_fitting_level XLA compilation options.
PiperOrigin-RevId: 727070422
2025-02-14 14:46:11 -08:00
Olli Lupton
1bba1ea2e2 Add JAX_COMPILATION_CACHE_EXPECT_PGLE option
This allows using external profiling tools, such as Nsight Systems,
with the automatic PGLE workflow supported by JAX with a simple two-step
workflow:

export JAX_COMPILATION_CACHE_DIR=...
JAX_ENABLE_PGLE=yes python model.py
JAX_COMPILATION_CACHE_EXPECT_PGLE=yes nsys profile python model.py
2025-02-06 08:19:45 +00:00
Skye Wanderman-Milne
2aa810fe60 Make JAX_CPU_COLLECTIVES_IMPLEMENTATION and JAX_NUM_CPU_DEVICES env vars
Before, these values could only be specified via jax.config or
flags. This PR makes them proper configs, so they also work as env
vars.
2025-01-28 17:17:56 -08:00
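The behaviour described in 2aa810fe60, where an option can come from an environment variable as well as from `jax.config` or flags, can be pictured with a small stand-alone sketch. The helper `read_int_option` and its precedence rule (explicit value, then environment, then default) are assumptions made for illustration, not JAX's implementation. Only the variable name `JAX_NUM_CPU_DEVICES` comes from the commit.

```python
import os
from typing import Mapping, Optional


def read_int_option(env_name: str,
                    explicit: Optional[int] = None,
                    default: Optional[int] = None,
                    environ: Mapping[str, str] = os.environ) -> Optional[int]:
    """Resolve an integer option: explicit value, then env var, then default."""
    if explicit is not None:
        return explicit
    raw = environ.get(env_name)
    if raw is not None:
        return int(raw)  # raises ValueError on a non-integer string
    return default


# A fake environment keeps the example independent of the real process state.
fake_env = {"JAX_NUM_CPU_DEVICES": "4"}

from_env = read_int_option("JAX_NUM_CPU_DEVICES", environ=fake_env)
overridden = read_int_option("JAX_NUM_CPU_DEVICES", explicit=8, environ=fake_env)
fallback = read_int_option("JAX_NUM_CPU_DEVICES", default=1, environ={})
```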
Yash Katariya
8f248fe626 [sharding_in_types] Upstream changes from defaulting sharding_in_types config to True experiment. There aren't a lot of failures in TGP but we can at least upstream these changes until we work on the failures.
PiperOrigin-RevId: 720639755
2025-01-28 11:04:42 -08:00
Peter Hawkins
95cb0eb1c9 Optimize JaxprEqnContext context manager.
* Implement the context manager as a context manager class, rather than using @contextlib.contextmanager. It turns out the contextlib contextmanagers are rather slow.
* Fuse the four child context managers into a single context manager. This saves us a bunch of allocations.
* While we are here, also simplify the xla_metadata context manager to avoid its dual representation of the current metadata.

PiperOrigin-RevId: 719918121
2025-01-26 12:08:44 -08:00
Peter Hawkins
776327919f Optimize implementation of the compute_on context manager.
* We don't need to keep a separate thread-local stack of objects: the config state already has a thread local.
* We don't need to keep an explicit stack of contexts at all: we can maintain it in the context manager frames.
* When checking for incompatible nested compute_ons, we can just check the current state: no need to look higher in the stack!

PiperOrigin-RevId: 719892989
2025-01-26 09:24:33 -08:00
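The three points in 776327919f can be sketched in plain Python: the state lives in a single thread-local slot, each context manager instance remembers the value it replaced instead of pushing onto a shared stack, and the nesting check looks only at the current value. The names `_local`, `current_compute_type` and `compute_on_sketch`, the labels `"host"` and `"device"`, and the rule that differing nested values are rejected are all assumptions for illustration, not the JAX implementation.

```python
import threading

_local = threading.local()  # hypothetical thread-local config state


def current_compute_type():
    """Return the value active on this thread, or None if unset."""
    return getattr(_local, "compute_type", None)


class compute_on_sketch:
    """Context manager that keeps its 'stack entry' on the instance itself."""

    def __init__(self, compute_type):
        self.compute_type = compute_type
        self.prev = None

    def __enter__(self):
        current = current_compute_type()
        # Differing nested values are rejected, so every enclosing value
        # equals the current one and checking it alone is enough.
        if current is not None and current != self.compute_type:
            raise ValueError(
                f"cannot nest {self.compute_type!r} inside {current!r}")
        self.prev = current
        _local.compute_type = self.compute_type
        return self

    def __exit__(self, exc_type, exc, tb):
        # Restore whatever was active before this frame was entered.
        _local.compute_type = self.prev
        return False
```

Because the previous value is stored on the instance, unwinding nested `with` blocks restores the state in the right order without any separate stack object.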
Yash Katariya
3aa55992fe Remove device_context from trace_context because we don't need it there. We can get compilation cache misses (and tracing/lowering cache hit) naturally without putting concrete devices into trace_context.
PiperOrigin-RevId: 718113413
2025-01-21 16:21:36 -08:00
Yash Katariya
5a068da699 Remove memories flag now that JAX 0.5.0 has been released since it always defaults to True.
PiperOrigin-RevId: 716908015
2025-01-17 22:13:04 -08:00
Roy Frostig
a60ead6fd1 enable partitionable threefry by default
PiperOrigin-RevId: 715242560
2025-01-13 22:46:24 -08:00
Peter Hawkins
e20523c2e3 Make api_test.py work when test cases are run using multiple threads.
* keep track of all known config.State objects so we can find them by name.
* change `@jtu.with_config` to default to setting thread-local configurations.
* add a `@jtu.with_global_config` for those things that truly need to be set globally.
* add a `@jtu.thread_local_config_context` that overrides thread-local configuration options, just as `jtu.global_config_context` overrides global configuration options.
* change the pretty printer color option to be a State so it can be set locally.
* tag a number of tests as thread-hostile, in particular tests that check counters for numbers of compilations, rely on garbage collection having particular semantics, or look at log output.

PiperOrigin-RevId: 713411171
2025-01-08 14:09:07 -08:00
Adam Paszke
d2f937e241 Make jax.Arrays a necessary part of the cycle in the GC guard test
Otherwise, the cycle can be broken by clearing the references of the helper
objects, at which point the deallocation of arrays proceeds through regular
reference counting (and does not trigger logs!). I have not verified that
this is what happens, but the test has been mysteriously failing under a
number of configurations and this seems to fix it.

I added a note to the garbage collection guard to clarify that it's not
guaranteed to report all cycles.

PiperOrigin-RevId: 708320953
2024-12-20 07:48:04 -08:00
Yash Katariya
8b734808e8 Remove jax_enable_memories config flag. It defaulted to True for a very long time and it's time to remove the flag.
PiperOrigin-RevId: 707590263
2024-12-18 10:15:45 -08:00
Matthew Johnson
42ac4ca357 ref errors
2024-12-18 07:46:14 +00:00
Yash Katariya
1e22149493 Fix the breakage caused by deleted enable_memories config
PiperOrigin-RevId: 707331603
2024-12-17 18:17:13 -08:00
Yash Katariya
cca9afa28f Delete enable_memories code in C++ since that flag is always True and cannot be turned off now.
PiperOrigin-RevId: 707298305
2024-12-17 16:43:20 -08:00
jax authors
a123d4e39e Remove autotune sharing.
xla_gpu_shard_autotuning can be used now instead and it is enabled by default.

PiperOrigin-RevId: 705792463
2024-12-13 01:22:27 -08:00
Peter Hawkins
79318a08cf Remove dead code after minimum jaxlib version bump to v0.4.36.
New minimum xla_extension_version is 299, and the new mlir_api_version is 57.

PiperOrigin-RevId: 704280856
2024-12-09 07:35:05 -08:00
jax authors
182e532675 Merge pull request #25114 from jedborovik:add-optimization-effort-flags
PiperOrigin-RevId: 702892538
2024-12-04 16:04:16 -08:00
Yash Katariya
a735bf83e5 Simplify abstract_mesh and device_context context managers and handle everything via their corresponding configs in config.py
PiperOrigin-RevId: 702852769
2024-12-04 14:04:25 -08:00
Jed Borovik
c65ce4b093 Merge branch 'main' into add-optimization-effort-flags
2024-11-27 14:08:10 -05:00
Yash Katariya
0d2dfea4b1 Add a private set_mesh API to enter into sharding_in_types mode. This is how users will enable sharding in types mode (with correct axis types set too but that doesn't work yet).
Also adding a device_context so `set_mesh` sets the devices the computation should run on correctly. The device_context however enters concrete devices into tracing and lowering cache but this should be fixed with the other jax context work going on.

PiperOrigin-RevId: 700537898
2024-11-26 20:01:04 -08:00
labs-code-app[bot]
762301fc5d Add exec_time_optimization_effort and memory_fitting_effort flags.
These flags control the amount of effort the compiler spends minimizing execution time and memory usage, respectively. They can be set via the command line, e.g. . Valid values are between -1.0 and 1.0, default is 0.0.
2024-11-26 13:57:47 +00:00
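The value range stated in 762301fc5d (between -1.0 and 1.0, default 0.0) can be expressed as a small validation sketch. The helper `check_effort` is hypothetical, and treating the bounds as inclusive is an assumption, since the commit message does not say either way. This is not how JAX or XLA validates the flags.

```python
def check_effort(name: str, value: float = 0.0) -> float:
    """Validate an effort value, assuming inclusive bounds of [-1.0, 1.0]."""
    value = float(value)
    # The chained comparison is False for NaN, so NaN is rejected as well.
    if not -1.0 <= value <= 1.0:
        raise ValueError(f"{name} must be between -1.0 and 1.0, got {value}")
    return value


default_effort = check_effort("exec_time_optimization_effort")
custom_effort = check_effort("memory_fitting_effort", 0.5)
```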
Yash Katariya
c35f8b22c1 Add abstract mesh context manager to trace_context in the fallback path too (which will be deleted after jax 0.4.36 release)
PiperOrigin-RevId: 700006186
2024-11-25 09:18:30 -08:00
Yash Katariya
40fc6598f9 [sharding_in_types] Make flash_attention forward pass in TPU pallas work nicely with sharding in types. Backward pass is still busted which I will fix in follow up CLs.
Set the abstract mesh context manager at the jit tracing boundary by looking at the mesh on the avals. In the future, this context manager will be user settable too.

Abstract mesh context manager is a new context manager with a new context variable and new trace_context entry which governs the cache behavior. If the abstract mesh context manager is not set, the default is `None`.

PiperOrigin-RevId: 698493184
2024-11-20 13:07:30 -08:00
Dougal
d0f17c0c04 Make a direct linearize trace.
This is an alternative to doing JVP followed by partial eval. The linearize
trace has two parent traces, one for the primal computation and one for the
tangent computation. If we make the tangent trace a DynamicJaxprTrace then we
get staged linearization. If we make it the same as the primal trace then we get
primal and tangent computations occurring in step (JVP). This is a neat trick
enabled by stackless which now lives up to its name. With two parent traces we
have a tree of traces not a linked list stack.

Primitive ops can have their own linearization rules but as a fallback we can
derive a linearization rule for a single op using jvp/partial-eval.

For now this is all under a flag, `use_direct_linearize`, but I'm hoping we can
make this the default for linearize/grad. It should help with remat and AD
through state which are awkward to express via partial eval.
2024-11-20 10:03:00 -08:00
Dan Foreman-Mackey
ccb331707e Add a GPU implementation of lax.linalg.eig.
This feature has been in the queue for a long time (see https://github.com/jax-ml/jax/issues/1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (https://github.com/jax-ml/jax/issues/24255; this should be investigated separately).

This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.)

We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_use_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable.

PiperOrigin-RevId: 697631402
2024-11-18 08:11:57 -08:00
jax authors
cea8176756 Merge pull request #24751 from Stella-S-Yan:feature/default_device_str
PiperOrigin-RevId: 696560063
2024-11-14 10:00:18 -08:00
Trevor Morris
a79d307ac7 When caching is enabled, also enable XLA caching features
Add unit test

Fix typechecker

Set caching mode depending on process id
2024-11-13 10:30:04 -08:00