272 Commits

Author SHA1 Message Date
Michael Hudgins
2e808f2836 Merge pull request #26279 from MichaelHudgins:tsan-resultstore
PiperOrigin-RevId: 723918760
2025-02-06 14:55:57 +00:00
George Necula
abcaec7081 [better_errors] Add debug info to the Jaxprs formed for AD
Following #26078 , we add debug info to more calls of lu.wrap_init.
2025-02-05 19:21:02 +02:00
Justin Fu
b01111d96c Add skeleton for a multi-pass source mapper for Jaxprs/HLO to jax.experimental.
PiperOrigin-RevId: 721119935
2025-01-29 15:01:43 -08:00
Peter Hawkins
95cb0eb1c9 Optimize JaxprEqnContext context manager.
* Implement the context manager as a context manager class, rather than using @contextlib.contextmanager. It turns out the contextlib contextmanagers are rather slow.
* Fuse the four child context managers into a single context manager. This saves us a bunch of allocations.
* While we are here, also simplify the xla_metadata context manager to avoid its dual representation of the current metadata.

PiperOrigin-RevId: 719918121
2025-01-26 12:08:44 -08:00
Peter Hawkins
776327919f Optimize implementation of the compute_on context manager.
* We don't need to keep a separate thread-local stack of objects: the config state already has a thread local.
* We don't need to keep an explicit stack of contexts at all: we can maintain it in the context manager frames.
* When checking for incompatible nested compute_ons, we can just check the current state: no need to look higher in the stack!

PiperOrigin-RevId: 719892989
2025-01-26 09:24:33 -08:00
Parker Schuh
3864512b72 Move transfer python bindings into jax.
PiperOrigin-RevId: 719082208
2025-01-23 17:52:57 -08:00
Dan Foreman-Mackey
e3b3b913f7 Add an experimental interface for customizing DCE behavior.
We use dead code elimination (DCE) throughout JAX core to remove unused computations from Jaxprs. This typically works transparently when we're just using `lax` primitives, but opaque calls to `pallas_call` or `ffi_call` can't be cleaned up this way. For many kernels however, the author will know how to generate a more efficient call for specific patterns of used outputs, so it is useful to provide a mechanism for customizing this behavior.

In https://github.com/jax-ml/jax/pull/22735, I attempted to automatically tackle one specific example of this that comes up frequently, but there have been feature requests for a more general API. This version is bare bones and probably rough around the edges, but it could be a useful starting point for iteration.

PiperOrigin-RevId: 718950828
2025-01-23 11:38:47 -08:00
Yash Katariya
b23c42372b [sharding_in_types] If an indexing operation hits into gather_p, error out saying to use .at[...].get(out_spec=...) instead.
This will basically drop the gather operation into full auto mode and add a sharding constraint on the output given by the user via `out_spec`.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 716295953
2025-01-16 10:51:15 -08:00
Sharad Vikram
0ac63157f5 [Pallas TPU] Add helpers file with copy_ref function
PiperOrigin-RevId: 716030813
2025-01-15 18:34:58 -08:00
Peter Hawkins
b06779b177 Switch to a new thread-safe utility for catching warnings.
The Python warnings.catch_warnings() functionality is not thread-safe (https://py-free-threading.github.io/porting/#the-warnings-module-is-not-thread-safe), so we cannot use it during tests that use free-threading. This change introduces a private warnings test helper (test_warning_util.py), which hooks the CPython warning infrastructure and uses it to implement thread-safe warnings infrastructure.

This requires a handful of small modifications to tests to remove direct uses of the warnings module. We also sadly have to delete one TPU test that checks for a warning raised on another thread; there's no easy way for us to catch that in a thread-safe way, but that test seems like overkill anyway.
2025-01-09 11:58:34 -05:00
Jake VanderPlas
640cb009f1 bazel visibility change
PiperOrigin-RevId: 713488528
2025-01-08 18:34:10 -08:00
Yash Katariya
3848f0d2ac [sharding_in_types] Functions like einsum, reshape, broadcast_in_dim, broadcasted_iota, convert_element_type and sharding_cast that take out_sharding as an argument in their signature should also allow PartitionSpec instead of just NamedSharding as an input.
If PartitionSpec is passed, the mesh is read from the context. The primitives though take `NamedSharding` only. The conversion from `PartitionSpec` to `NamedSharding` happens above `.bind`.

We also raise an error if `PartitionSpec` contain mesh axis names that are of type Auto or Collective for the above functions.

PiperOrigin-RevId: 713352542
2025-01-08 11:11:16 -08:00
jax authors
56f0f9534d Merge pull request #25633 from dfm:move-ffi
PiperOrigin-RevId: 712863350
2025-01-07 04:40:21 -08:00
Jake VanderPlas
c7b0d681bd Remove deprecated jax.experimental.array_api 2025-01-06 15:19:02 -08:00
John QiangZhang
c39e38fe5a bazel: export serialization.fbs for downstream usage
PiperOrigin-RevId: 712587802
2025-01-06 10:57:35 -08:00
Dan Foreman-Mackey
cb4d97aa1f Move jex.ffi to jax.ffi. 2024-12-29 13:06:19 +00:00
jax authors
f65ecedde7 Merge pull request #25593 from mattjj:ref-errors-4
PiperOrigin-RevId: 707733777
2024-12-18 18:23:33 -08:00
Matthew Johnson
e52856261f add mutable array ref error checks to scan 2024-12-19 01:33:39 +00:00
Yash Katariya
b5e4fd161d [sharding_in_types] Enforce AxisTypes to always exist if set_mesh is used.
Also support `Auto` mode fully or mixed in with `User` mode. This works by overriding the sharding of `Auto` axes in the PartitionSpec with `Unconstrained` in `ShapedArray` constructor. The `ShapedArray` constructor is the central place where we can make such substitutions.

During lowering of shardings with auto axes, we mark the auto dims are `unspecifed_dims`. We don't mark all dims as unspecified because that would enable XLA to shard them even further which is not what we want if some of the dims are user sharded.

PiperOrigin-RevId: 704911253
2024-12-10 18:03:21 -08:00
Bixia Zheng
2a4a0e8d6f [jax:custom_partitioning] Implement SdyShardingRule to support
Shardy custom_partitioning.

The parsing of the sharding rule string very closely follows how einops parses
their rules in einops/parsing.py.

When a SdyShardingRule object is constructed, we check the syntax of the Einsum
like notation string and its consistency with the user provided factor_sizes,
and report errors accordingly. This is done during f.def_partition.

When SdyShardingRule.build is called, during JAX to MLIR lowering, we check
the consistency between the Einsum like notation string, the factor_sizes
and the MLIR operation, and report errors accordingly.

PiperOrigin-RevId: 703187962
2024-12-05 11:33:23 -08:00
Enrique Piqueras
8c521547b7
Add experimental JAX roofline API. 2024-11-27 14:38:57 -08:00
Hyeontaek Lim
bbaec6ea59 [JAX] Add Python binding for building a colocated Python program
This change adds a Python binding that makes `ifrt::CustomCallProgram` for a
colocated Python program. This Python binding will be used internally in the
colocated Python API implementation. The API does not yet compile the program
into an executable, which will be added separately.

PiperOrigin-RevId: 700443656
2024-11-26 13:31:15 -08:00
jax authors
231967fdb5 [AutoPGLE] Explicitly ignore host callback pointers
Before this change users had to specify remove_custom_partitioning_ptr_from_cache_key config flag when using AutoPGLE.

PiperOrigin-RevId: 700289965
2024-11-26 04:06:15 -08:00
Yash Katariya
40fc6598f9 [sharding_in_types] Make flash_attention forward pass in TPU pallas work nicely with sharding in types. Backward pass is still busted which I will fix in follow up CLs.
Set the abstract mesh context manager at the jit tracing boundary by looking at the mesh on the avals. In the future, this context manager will be user settable too.

Abstract mesh context manager is a new context manager with a new context variable and new trace_context entry which governs the cache behavior. If the abstract mesh context manager is not set, the default is `None`.

PiperOrigin-RevId: 698493184
2024-11-20 13:07:30 -08:00
Jake VanderPlas
8c71d1ad6d Make deprecated jax.experimental.array_api module visibility internal-only
This is in preparation for the module to be removed.

PiperOrigin-RevId: 698215225
2024-11-19 18:33:07 -08:00
Trevor Morris
a79d307ac7 When caching is enabled, also enable XLA caching features as well
Add unit test

Fix typechecker

Set caching mode depending on process id
2024-11-13 10:30:04 -08:00
Sergei Lebedev
d304025a41 [mosaic_gpu] The profiler now uses FFI calls for creating events and computing elapsed time
PiperOrigin-RevId: 695798787
2024-11-12 11:01:59 -08:00
Adam Paszke
8b21614973 [Pallas:MGPU] Add FlashAttention3 as an example
PiperOrigin-RevId: 690977852
2024-10-29 05:21:43 -07:00
Hyeontaek Lim
77797f434d [JAX] Add the function API of jax.experimental.colocated_python
This change adds an experimental API `jax.experimental.colocated_python`. The
ultimate goal of this API is to provide a runtime-agnostic way to wrap a Python
code that runs close to (or on) accelerator hosts. Multi-controller JAX can
trivially achieve this colocated Python code execution today, while
single-controller JAX needed its own solution for distributed Python code
execution, which creates fragmentation of the user code for these two runtime
architectures. `colocated_python` is an attempt to define a single device model
and portable API to allow the user to write a single code once that can run on
both runtime architectures.

This change includes an implementation of the function API portion of
`jax.experimental.colocated_python`. A (stateful) object API will be added
separately. Also there will be a separate change that expresses serialized
functions as an IFRT `CustomCallProgram`.

It is currently in an early development stage. Please proceed with a caution
when using the API.

PiperOrigin-RevId: 690705899
2024-10-28 12:18:48 -07:00
Sergei Lebedev
dfa6fcd56b [pallas:mosaic_gpu] Extracted a basic emit_pipeline API from the in kernel pipelining test
PiperOrigin-RevId: 690619853
2024-10-28 08:25:47 -07:00
Sergei Lebedev
5a2128e44b [pallas] Removed deprecated aliases to CostEstimate and run_scoped
PiperOrigin-RevId: 689871787
2024-10-25 12:16:58 -07:00
Sergei Lebedev
06c08bd118 Renamed :pallas_gpu to :pallas_triton
:pallas_gpu is now an umbrella target for Triton and (hopefully soon)
Mosaic GPU backends.

PiperOrigin-RevId: 683145270
2024-10-07 05:44:00 -07:00
Sergei Lebedev
95631a7d92 Added jax.experimental.pallas.mosaic_gpu
I also deprecated `jax.experimental.pallas.gpu` in favor of
`jax.experimental.pallas.triton` to avoid confusion with the Mosaic GPU
backend.

PiperOrigin-RevId: 683119193
2024-10-07 04:05:08 -07:00
Tom Natan
ed5ba633d4 Reverts 6cf09f8c24c67ff650b95d174501fff3cb59db0d
PiperOrigin-RevId: 682440543
2024-10-04 13:56:27 -07:00
Justin Fu
350afaa7b6 [Pallas] Clean up lowering exceptions.
PiperOrigin-RevId: 681073628
2024-10-01 10:26:40 -07:00
Tom Natan
6cf09f8c24 Reverts eff00cc4499cfe3f3f24bafda6c1ecf908232ff3
PiperOrigin-RevId: 678756266
2024-09-25 10:33:53 -07:00
Tom Natan
eff00cc449 [JAX] add support for gather/scatter batching dims following the new attributes in stablehlo.
This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See https://github.com/openxla/stablehlo/pull/2259

PiperOrigin-RevId: 678649138
2024-09-25 04:53:11 -07:00
jax authors
9465d427c0 Merge pull request #22302 from yhtang:add-k8s-initialize
PiperOrigin-RevId: 676962862
2024-09-20 14:03:50 -07:00
Yu-Hang Tang
c88c3aecae add k8s cluster environment 2024-09-20 17:26:53 +00:00
Jevin Jiang
839ce9a11d [Pallas TPU] Refactor ref indexers to transforms and support ref bitcast.
This cl refactors Pallas memref indexers to transforms which can support different ref transforms: indexing, bitcast (added in this cl), reshape (to be added) and others. Like indexer, user can apply multiple transforms to same memref, eg:
```
ref.bitcast(type1).at[slice1].bitcast(type2).bitcast(type3).at[slice2]...
```

Jaxpr Preview (apply multiple transforms to same ref):
```
{ lambda ; a:MemRef<None>{int32[16,256]} b:MemRef<None>{int32[8,128]}. let
    c:i32[8,128] <- a[:8,:][bitcast(int16[16,256])][bitcast(float16[16,256])][:,:128][bitcast(int32[8,128])][:,:]
    b[:,:] <- c
  in () }
```

Tested:
* DMA with bitcasted ref
* Load from bitcasted ref
* Store to bitcasted ref
* Multiple transforms
* Interpret Mode for ref transforms (updated discharge rules)
PiperOrigin-RevId: 674961388
2024-09-15 17:53:29 -07:00
jax authors
02b7a76768 Add frontend attributes to Jax. This allows Jax users to annotate Jax code with frontend_attributes which can be traced down to the HLO level, to be used for numerical debugging purposes.
PiperOrigin-RevId: 671930431
2024-09-06 16:44:56 -07:00
Yash Katariya
a144eb234b Add compute_on_context_manager to thread local jit state. This is to avoid getting false cache hits
PiperOrigin-RevId: 671507042
2024-09-05 14:16:13 -07:00
Justin Fu
2d74c6aa05 Add TritonCompilerParams for specifying compiler arguments instead of a dict.
PiperOrigin-RevId: 671081069
2024-09-04 13:32:25 -07:00
Yash Katariya
252caebce3 Create jax.make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str], devices: Sequence[jax.Device] | None = None) API to make it easier to create a mesh and reduce a ton of boilerplate.
`jax.make_mesh` is the stable API endpoint of `mesh_utils` but without all the extra options. If you want those, you can still use the experimental endpoint in `mesh_utils`.

PiperOrigin-RevId: 670707995
2024-09-03 14:32:03 -07:00
Peter Hawkins
6d1f51e63d Clean up BUILD files.
PiperOrigin-RevId: 667604964
2024-08-26 09:11:17 -07:00
Jieying Luo
a3ae5e18d3 Remove build_cuda_plugin_from_source flag which is no longe used.
751b5742fd

PiperOrigin-RevId: 661370449
2024-08-09 12:54:14 -07:00
Jake VanderPlas
48c5fab023 [array api] fix deprecation to support old import pattern 2024-08-01 14:38:59 -07:00
Jake VanderPlas
14fa06298e [array api] Finalize array API in jax.numpy & deprecate jax.experimental.array_api 2024-08-01 11:19:17 -07:00
Christos Perivolaropoulos
80a193d5db [pallas] Use the same primitive run_scoped_p for moth mosaic and mosaic_gpu
PiperOrigin-RevId: 655751205
2024-07-24 17:14:30 -07:00
Yash Katariya
0d5dae09ff Delete xmap and the jax.experimental.maps module. It's been 5 months since its deprecation (more than the standard 3 months deprecation period).
PiperOrigin-RevId: 655614395
2024-07-24 10:24:09 -07:00