24334 Commits

Author SHA1 Message Date
Hyeontaek Lim
bbaec6ea59 [JAX] Add Python binding for building a colocated Python program
This change adds a Python binding that makes `ifrt::CustomCallProgram` for a
colocated Python program. This Python binding will be used internally in the
colocated Python API implementation. The API does not yet compile the program
into an executable, which will be added separately.

PiperOrigin-RevId: 700443656
2024-11-26 13:31:15 -08:00
Yash Katariya
6763fcfb4e Fix a weird interaction with set_local and empty tuples passed to it.
PiperOrigin-RevId: 700392735
2024-11-26 10:50:05 -08:00
Vladimir Belitskiy
e453fa179e Update XLA dependency to use revision
PiperOrigin-RevId: 700373062
2024-11-26 09:48:02 -08:00
jax authors
92e18e6d5c [AutoPGLE] Fix pgle test after removing pjit cache.
PiperOrigin-RevId: 700359385
2024-11-26 08:58:15 -08:00
Ayaka
dc11d402f5 [Pallas TPU] Better error message for lowering sp.broadcast_to_p
`sp.broadcast_to_p` is a GPU-specific primitive, but it mistakenly appears in TPU lowerings. This PR improves the error message to reflect this.

As an example, currently, users will hit this error when doing:

```
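import jax.numpy as jnp
from jax.experimental import pallas as pl

def kernel(x_ref, o_ref):
    m, n = 32, 8
    # Load with broadcasted integer-array indices.
    x = pl.load(x_ref, (jnp.arange(m, dtype=jnp.int32)[:, None], jnp.arange(n, dtype=jnp.int32)[None]))
    o_ref[...] = x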
```

PiperOrigin-RevId: 700290975
2024-11-26 04:09:33 -08:00
jax authors
231967fdb5 [AutoPGLE] Explicitly ignore host callback pointers
Before this change, users had to specify the remove_custom_partitioning_ptr_from_cache_key config flag when using AutoPGLE.

PiperOrigin-RevId: 700289965
2024-11-26 04:06:15 -08:00
Sergei Lebedev
b6566c80b0 [mosaic_gpu] Fixed unbounded recursion in FragmentedArray._pointwise
PiperOrigin-RevId: 700265616
2024-11-26 02:14:54 -08:00
jax authors
16a5607c91 Use xla_extension_version instead of jaxlib_version
PiperOrigin-RevId: 700265297
2024-11-26 02:12:57 -08:00
jax authors
024e331441 Merge pull request #25084 from ROCm:ci_rocm_version
PiperOrigin-RevId: 700241231
2024-11-26 00:35:01 -08:00
Christos Perivolaropoulos
f828f2d7d0 [mgpu] Pointwise min
PiperOrigin-RevId: 700175724
2024-11-25 19:13:51 -08:00
Yash Katariya
627debc78b Create a null_mesh_context internal context manager to handle null contexts properly.
PiperOrigin-RevId: 700167406
2024-11-25 18:32:05 -08:00
Yash Katariya
59e13f8114 Add sharding argument to reshape since it also takes a shape argument for the output shape
PiperOrigin-RevId: 700163883
2024-11-25 18:16:08 -08:00
Christos Perivolaropoulos
c5dc980db8 [mgpu/pallas_mgpu] Pointwise tanh support
PiperOrigin-RevId: 700158250
2024-11-25 17:56:11 -08:00
Christos Perivolaropoulos
ef7df1ae7c [pallas_mgpu] Allow trees (eg tuples) to be returned from cond_p expressions.
PiperOrigin-RevId: 700136799
2024-11-25 16:36:43 -08:00
jax authors
ebea4353f8 Update XLA dependency to use revision
7059553f7e.

PiperOrigin-RevId: 700110142
2024-11-25 14:57:20 -08:00
Nitin Srinivasan
f7e9f62537 Add new CI scripts for building JAX artifacts
This commit introduces new CI scripts and environment files for building JAX artifacts. The scripts use the artifact envs inside the "ci/envs/build_artifacts" folder to control the build behavior. For example, to build jaxlib, run `./ci/build_artifacts.sh ./ci/envs/build_artifacts/jaxlib.env` from the root of the JAX GitHub repository.

PiperOrigin-RevId: 700104283
2024-11-25 14:37:02 -08:00
jax authors
788f4935ec Merge pull request #25041 from dfm:ffi-example-refactor
PiperOrigin-RevId: 700093685
2024-11-25 14:04:14 -08:00
Nitin Srinivasan
6761512658 Re-factor build CLI to a subcommand based approach
This commit reworks the JAX build CLI to a subcommand-based approach where CLI use cases are defined as subcommands. Two subcommands are defined: build and requirements_update. Use "build" to build a JAX wheel package, and "requirements_update" to update the requirements_lock.txt files. The new structure offers a clear and organized CLI that lets users execute specific build tasks without having to navigate a monolithic script.

Each subcommand has specific arguments that apply to its respective build process. In addition, arguments are separated into groups, which gives a cleaner separation and improves readability when the CLI subcommands are run with `--help`. It also makes clear which parts of the build they affect, e.g., CUDA arguments only apply to CUDA builds, ROCM arguments only apply to ROCM builds, etc. This reduces complexity and the potential for errors during the build process. Segregating functionality into distinct subcommands also simplifies the code, which should help with maintenance and future extensions.
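
The subcommand and argument-group layout described here can be sketched with `argparse`. This is a minimal illustration under stated assumptions, not the actual `build/build.py`: the function name `make_parser`, the group titles, the help strings, and the default values are all hypothetical, and only the flag names are taken from the usage examples in this commit message.

```python
import argparse


def make_parser() -> argparse.ArgumentParser:
    """Hypothetical sketch of a subcommand-based build CLI (not the real build.py)."""
    parser = argparse.ArgumentParser(prog="build.py")
    subparsers = parser.add_subparsers(dest="command", required=True)

    # "build" subcommand: builds one or more wheels.
    build = subparsers.add_parser("build", help="Build wheel packages")
    build.add_argument("--wheels", default="jaxlib",  # default is an assumption
                       help="Comma-separated list of wheels to build")
    build.add_argument("--python_version", default="3.10")

    # Argument groups only change how --help is laid out; parsing is unchanged.
    cuda = build.add_argument_group("CUDA options")
    cuda.add_argument("--cuda_version")
    cuda.add_argument("--cudnn_version")
    rocm = build.add_argument_group("ROCm options")
    rocm.add_argument("--rocm_version")
    rocm.add_argument("--rocm_path")

    # "requirements_update" subcommand: refreshes the lock files.
    req = subparsers.add_parser("requirements_update",
                                help="Update requirements_lock.txt files")
    req.add_argument("--python_version", default="3.10")
    return parser


args = make_parser().parse_args(
    ["build", "--wheels=jaxlib,jax-cuda-plugin", "--cuda_version=12.3.2"])
print(args.command, args.wheels.split(","), args.cuda_version)
```

Because each subcommand owns its own parser, a flag such as `--cuda_version` is rejected under `requirements_update`, which is one way the structure can reduce the room for mistakes.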

The build commands are now executed with `asyncio.create_subprocess_shell` instead of `subprocess.check_output`, which allows streaming logs and shows the build progress in real time.
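
To illustrate why this switch enables live logs: `subprocess.check_output` returns the output only after the command has finished, whereas an asyncio subprocess exposes its stdout as a stream that can be read line by line. The sketch below shows one way this might look; the helper name `run_streaming` and the sample command are assumptions, not code from the build script.

```python
import asyncio
import sys


async def run_streaming(cmd: str) -> int:
    """Run `cmd` in a shell, echoing each output line as soon as it arrives."""
    proc = await asyncio.create_subprocess_shell(
        cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.STDOUT,  # merge stderr into stdout
    )
    assert proc.stdout is not None
    async for raw_line in proc.stdout:  # the StreamReader yields lines as bytes
        sys.stdout.write(raw_line.decode(errors="replace"))
        sys.stdout.flush()
    # Wait for the process to exit and return its exit code.
    return await proc.wait()


if __name__ == "__main__":
    code = asyncio.run(run_streaming("echo step 1 && echo step 2"))
    print("exit code:", code)
```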

Usage:
* Building `jaxlib`:
```
python build/build.py build --wheels=jaxlib --python_version=3.10
```
* Building `jax-cuda-plugin`:
```
python build/build.py build --wheels=jax-cuda-plugin --cuda_version=12.3.2 --cudnn_version=9.1.1 --python_version=3.10
```
* Building multiple packages:
```
python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt --cuda_version=12.3.2 --cudnn_version=9.1.1 --python_version=3.10
```
* Building `jax-rocm-pjrt`:
```
python build/build.py build --wheels=jax-rocm-pjrt --rocm_version=60 --rocm_path=/path/to/rocm
```
* Using a local XLA path:
```
python build/build.py build --wheels=jaxlib --local_xla_path=/path/to/xla
```
* Updating requirements_lock.txt files:
```
python build/build.py requirements_update --python_version=3.10
```

For more details on each argument and to see available options, run:
```
python build/build.py build --help
```
or
```
python build/build.py requirements_update --help
```

PiperOrigin-RevId: 700075411
2024-11-25 13:03:04 -08:00
Bill Varcho
f22bafac31 [SDY] remove TODO for enabling Layouts for Shardy post cl/697715276.
PiperOrigin-RevId: 700053383
2024-11-25 11:45:00 -08:00
jax authors
95029abc18 drop compute capability check
PiperOrigin-RevId: 700052796
2024-11-25 11:42:56 -08:00
Justin Fu
107bc96c29 [Mosaic GPU] Support batch dimensions in FA3 MGPU kernel.
PiperOrigin-RevId: 700052530
2024-11-25 11:42:39 -08:00
Yash Katariya
deab6fbd80 Remove _pjit_lower_cached cache. We can simplify the caching of jit as we have downstream caches and a cpp cache too.
If you drop out of cpp cache, things are going to be slow anyways.

PiperOrigin-RevId: 700052522
2024-11-25 11:40:50 -08:00
Bill Varcho
066859e62f [SDY] Enable test_pjit_array_multi_input_multi_output since Shardy conflict resolution is now complete.
PiperOrigin-RevId: 700042542
2024-11-25 11:10:00 -08:00
Bill Varcho
bb1024f3fd [SDY] enable cpu_shardy for JAX shard_alike test.
PiperOrigin-RevId: 700029576
2024-11-25 10:33:17 -08:00
Yash Katariya
c35f8b22c1 Add abstract mesh context manager to trace_context in the fallback path too (which will be deleted after jax 0.4.36 release)
PiperOrigin-RevId: 700006186
2024-11-25 09:18:30 -08:00
jax authors
aa05dc0b5c Automated Code Change
PiperOrigin-RevId: 699991540
2024-11-25 08:31:06 -08:00
Ruturaj4
e8934b95eb [ROCm] Add rocm version information
2024-11-25 10:21:48 -06:00
Adam Paszke
914600a063 [Mosaic GPU] Simplify logic for pointwise splat operands
The previous version of the code was too complicated and failed to account
for the fact that in an op that broadcasts there does not necessarily exist
an operand that has the output shape.

Reading through the code now, it's a bit weird that we allow implicit
broadcasting of operands with splat layouts, but not any other operands.
But I guess that's a thing to implement later.

PiperOrigin-RevId: 699983045
2024-11-25 08:00:21 -08:00
Dan Foreman-Mackey
84a9cba85b Refactor FFI examples to consolidate several examples into one submodule.
2024-11-25 09:08:20 -05:00
Peter Buchlovsky
69e3f0d37d [pallas:mosaic_gpu] Add test for FragmentedArray.bitcast.
PiperOrigin-RevId: 699919048
2024-11-25 03:30:57 -08:00
jax authors
b372ce4b1a Update XLA dependency to use revision
40d457a268.

PiperOrigin-RevId: 699768724
2024-11-24 15:00:10 -08:00
jax authors
4d8751bff4 Update XLA dependency to use revision
90af2896ab.

PiperOrigin-RevId: 699545393
2024-11-23 15:31:36 -08:00
jax authors
e53ff2cbfc [Mosaic][Easy] - Wire up kernel names to MLIR dump
PiperOrigin-RevId: 699408419
2024-11-22 23:39:38 -08:00
jax authors
b259fde541 Fix member access to xla backend. The correct member is client instead of backend
PiperOrigin-RevId: 699338495
2024-11-22 17:39:44 -08:00
Yash Katariya
8699f5d970 When host local inputs on all hosts are the same, use _DeferredShardArg to do the transfers instead of jit to avoid blocking.
PiperOrigin-RevId: 699336402
2024-11-22 17:32:49 -08:00
jax authors
030ee4a1b2 Merge pull request #25070 from jax-ml:pjit-lin-rule
PiperOrigin-RevId: 699304829
2024-11-22 15:25:58 -08:00
jax authors
9f6dbef3dc Update XLA dependency to use revision
0564969ba3.

PiperOrigin-RevId: 699295115
2024-11-22 14:50:57 -08:00
Dougal
b1d1dcf607 Add linearization rule for pjit_p
2024-11-22 14:24:46 -08:00
Yash Katariya
21f8885a9e [sharding_in_types] Make argmax and argmin work with sharding_in_types. This also requires adding reduce_p sharding rule
PiperOrigin-RevId: 699244204
2024-11-22 12:00:22 -08:00
Yash Katariya
7635605262 Use with_spec where possible to clean up the code a bit
PiperOrigin-RevId: 699226058
2024-11-22 11:01:58 -08:00
Keith Rush
c0811c9dff Adds coverage for spmd-axisname-filtering in shard_map transpose.
PiperOrigin-RevId: 699193349
2024-11-22 09:14:29 -08:00
Nitin Srinivasan
34a2f0ca4a Add a jaxlib at head build to the cloud-tpu-ci-nightly workflow
This will allow us to test TPU compatibility with jaxlib at head. Also, enable v4 runners as they are now online.

PiperOrigin-RevId: 699155667
2024-11-22 06:45:36 -08:00
Justin Fu
73fa0f48cb [Pallas] Deprecate dictionary compiler_params in favor of dataclass.
PiperOrigin-RevId: 699057658
2024-11-21 23:34:32 -08:00
Yash Katariya
355589f32b [sharding_in_types] Add scan support to sharding_in_types. There are a couple of changes here
* Set abstract_mesh context manager during pjit_p.bind at the top level too since scan builds jaxpr during its lowering in `_scan_impl` (do the same for AOT path)

* Set the abstract mesh only once if it's not set. Don't override an already set context. This means that only top level jit sets the context manager.

* Add dynamic_slice and dynamic_update_slice sharding rules since scan calls into them.

* scan only allows `xs` where the 0th dim is fully replicated, i.e. None.

PiperOrigin-RevId: 699014167
2024-11-21 20:13:23 -08:00
jax authors
3d79df2464 Merge pull request #25048 from jax-ml:linearization-rule-signature
PiperOrigin-RevId: 699007033
2024-11-21 19:47:26 -08:00
Dougal
170718c8d4 Change signature of linearization rules.
Give the rule the nonzero tangent pattern up-front. This is needed to make a
linearization rule for pjit_p. Also make the rules return the nonzero tangents
out, an explicit residual, and a closed tangent function. Add a rule for sin_p
to test it out. We still need to figure out how to avoid having to precompute
`cos(x)`. I think we need to update our backward pass code.
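
As a rough illustration of the rule shape described here, the sketch below writes a linearization rule for `sin` on plain Python floats. The name `sin_lin_rule`, the argument order, and the exact return tuple are assumptions made for illustration; the real JAX rule operates on traced values and may differ.

```python
import math


def sin_lin_rule(nonzero_tangents_in, x):
    """Hypothetical linearization rule for sin on Python floats.

    Returns (primal_out, nonzero_tangent_out, residual, tangent_fn).
    """
    (nz,) = nonzero_tangents_in  # the nonzero-tangent pattern is given up-front
    residual = math.cos(x)  # precomputed eagerly, the cost the message mentions

    def tangent_fn(res, t):
        # Closed tangent function: depends only on its arguments.
        return res * t

    return math.sin(x), nz, residual, tangent_fn


y, nz_out, res, f_lin = sin_lin_rule([True], 0.0)
print(y, nz_out, f_lin(res, 2.0))  # prints: 0.0 True 2.0
```

Passing the residual explicitly, instead of capturing it in a closure, is what keeps the tangent function closed.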
2024-11-21 19:03:42 -08:00
Justin Fu
344d0d998d [Pallas] Add readme page for debugging tips.
PiperOrigin-RevId: 698939951
2024-11-21 15:43:23 -08:00
jax authors
26443bbd66 Update XLA dependency to use revision
85360d67ff.

PiperOrigin-RevId: 698915433
2024-11-21 14:26:26 -08:00
Jevin Jiang
f899d51535 [Mosaic TPU] Fold sublane offset to indices when storing to untiled ref.
This optimization avoids unnecessary retiling when storing to an untiled ref but adds at most one extra store op for the sublane offset (since the sublane offset is limited to < VregSlice[0]).

PiperOrigin-RevId: 698896373
2024-11-21 13:29:06 -08:00
Kyle Lucke
f3e7e6829a Remove unneeded dependency from rocm_plugin_extension.
PiperOrigin-RevId: 698872849
2024-11-21 12:18:11 -08:00