Otherwise, the cycle can be broken by clearing the references of the helper
objects, at which points the deallocation of arrays proceeds through regular
reference counting (and does not trigger logs!). I have not verified that
this is what happens, but the test has been mysteriously failing under a
number of configurations and this seems to fix it.
I added a note to the garbage collection guard to clarify that it's not
guaranteed to report all cycles.
PiperOrigin-RevId: 708320953
Also adding a device_context so `set_mesh` sets the devices the computation should run on correctly. The device_context however enters concrete devices into tracing and lowering cache but this should be fixed with the other jax context work going on.
PiperOrigin-RevId: 700537898
These flags control the amount of effort the compiler spends minimizing execution time and memory usage, respectively. They can be set via the command line, e.g. . Valid values are between -1.0 and 1.0, default is 0.0.
Set the abstract mesh context manager at the jit tracing boundary by looking at the mesh on the avals. In the future, this context manager will be user settable too.
Abstract mesh context manager is a new context manager with a new context variable and new trace_context entry which governs the cache behavior. If the abstract mesh context manager is not set, the default is `None`.
PiperOrigin-RevId: 698493184
This is an alternative to doing JVP followed by partial eval. The linearize
trace has two parent traces, one for the primal computation and one for the
tangent computation. If we make the tangent trace a DynamicJaxprTrace then we
get staged linearization. If we make it the same as the primal trace then we get
primal and tangent computations occurring in step (JVP). This is a neat trick
enabled by stackless which now lives up to its name. With two parent traces we
have a tree of traces not a linked list stack.
Primitive ops can have their own linearization rules but as a fallback we can
derive a linearization rule for a single op using jvp/partial-eval.
For now this is all under a flag, `use_direct_linearize`, but I'm hoping we can
make this the default for linearize/grad. It should help with remat and AD
through state which are awkward to express via partial eval.
This feature has been in the queue for a long time (see https://github.com/jax-ml/jax/issues/1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (https://github.com/jax-ml/jax/issues/24255; this should be investigated separately).
This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.)
We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_use_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable.
PiperOrigin-RevId: 697631402
A noticeable amount of time during JAX tracing is spent getting and setting the value of config.State objects, in particular the thread-local values within that state. If we move that logic into C++, we can speed up that code.
There are two main ways we can get a speedup:
* Python thread-local state is based around a dictionary and isn't terribly fast.
* we can have the C++ jit dispatch path directly access the configuration items it needs to include in its cache key. We spend a considerable amount of time in effect eagerly computing cache keys via update_thread_local_jit_state, although most of that is pointless work. Instead, we can have `jit` simply pull the config items it needs on demand.
PiperOrigin-RevId: 693114411
Introduces `jax.config.array_garbage_collection_guard`, which is a tristate config for setting up a `jax.Array` garbage collection guard. The possible configs are:
* allow: `jax.Array`s are allowed to be garbage collected. This is the default value.
* log: whenever a `jax.Array` is GCed a log entry is generated with the array's traceback.
* fatal: fatal crash when a `jax.Array` is GCed. This is meant to be used for mature code bases that do tight memory management, and are reference cycle free.
PiperOrigin-RevId: 687003464
This changes:
* The performance of array[0] (use array[0:1] instead).
* The shape of jax_array.addressable_shards or jax_array.addressable_data(0) of arrays that come from pmap.
PiperOrigin-RevId: 673564995
There will be more improvements and semantics clarification coming in the future as we integrate it more into JAX.
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
PiperOrigin-RevId: 668991384
The OpenXLA project is working on an open source, MLIR, named-axis based propagation (and in the future SP<D partitioning) system that will be dialect agnostic (would work for any dialect - MHLO, StableHLO, YourDialect). We plan on having frontends like JAX and PyTorch target this when using XLA and wanting SPMD propagation/partitioning. See www.github.com/openxla/shardy for more info.
Currently Shardy is implemented inside the XLA compiler, requiring us to round-trip between StableHLO and HLO with `mhlo.sharding`s. But we will eventually make Shardy the first pass in the XLA pipeline while it's still working on StableHLO. Partitioning (the system that adds the collectives like all-gathers/all-reduces) will still be the GSPMD Partitioner, but next year the Shardy partitioner will be developed, allowing for propagation and partitioning to be completely in MLIR and the first pass in the pipeline. So then we'd have:
1. Traced jaxpr
2. Jaxpr -> StableHLO
3. StableHLO with Shardy propagation
4. StableHLO with Shardy partitioning
5. StableHLO -> HLO
6. XLA optimizations
The following test:
```py
def test_sdy_lowering(self):
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
@partial(jax.jit, out_shardings=s)
def f(x):
return x * 2
print(f.lower(arr).as_text())
```
outputs:
```
module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
sdy.mesh @mesh = <"x"=4, "y"=2>
func.func public @main(%arg0: tensor<8x2xi64> {mhlo.layout_mode = "{1,0}", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi64> {jax.result_info = "", mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
%c = stablehlo.constant dense<2> : tensor<i64>
%0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i64>) -> tensor<8x2xi64>
%1 = stablehlo.multiply %arg0, %0 : tensor<8x2xi64>
return %1 : tensor<8x2xi64>
}
}
```
Shardy will be hidden behind the `jax_use_shardy_partitioner` flag initially before becoming enabled by default in the future.
PiperOrigin-RevId: 655127611
Also disable many of the non-native-serialization jax2tf tests.
In particular, I am disabling the thousands of primitives tests in
graph serialization mode.
I kept jax2tf_test running in both native and graph serialization mode.
PiperOrigin-RevId: 652749891
This allows lowering of threefry2x32 for GPU even on a machine without GPUs.
For the next 3 weeks, we only use the new custom call implementation if
we are not in "export" mode, and if we use a new jaxlib.
PiperOrigin-RevId: 647657084