1970 Commits

Author SHA1 Message Date
Peter Hawkins
e4f3f8f064 Use libtpu releases rather than libtpu-nightly for jax[tpu].
PiperOrigin-RevId: 688632409
2024-10-22 11:47:07 -07:00
jax authors
a2e4aff897 Merge pull request #24425 from dfm:rename-vmap-methods
PiperOrigin-RevId: 688547393
2024-10-22 07:51:29 -07:00
Hernan Moraldo
5d3cac6603 Fix documentation.
PiperOrigin-RevId: 688293390
2024-10-21 15:29:59 -07:00
jax authors
11eeff072f Merge pull request #22410 from garymm:patch-1
PiperOrigin-RevId: 688265373
2024-10-21 14:05:22 -07:00
Justin Fu
0b46a236c1 Update Pallas distributed tutorials with jax.make_mesh 2024-10-21 12:49:56 -07:00
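For reference, a minimal sketch of the `jax.make_mesh` entry point the tutorials now use (a single-axis mesh over all local devices; the axis name here is an arbitrary choice):

```python
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

# Build a one-dimensional mesh spanning all available devices.
mesh = jax.make_mesh((jax.device_count(),), ("x",))

# Shard an array along that axis.
x = jnp.arange(jax.device_count(), dtype=jnp.float32)
y = jax.device_put(x, NamedSharding(mesh, P("x")))
print(y.sharding)
```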
Dan Foreman-Mackey
61701af4a2 Rename vmap methods for callbacks. 2024-10-21 15:03:04 -04:00
Gary Miguel
dc908b4843 Update installation instructions
Apple GPUs with Mac x86_64 is a nonexistent combination.
Mac x86_64 with an AMD GPU is supported.

It's a bit of a confusing situation and hard to summarize, but hopefully this is more accurate and less confusing.

Fixes: #24408
2024-10-21 10:20:09 -07:00
Dan Foreman-Mackey
0b651f0f45 Make ffi_call return a callable 2024-10-21 12:16:57 -04:00
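A rough sketch of the new calling convention; `my_custom_op` is a hypothetical FFI target that would have to be registered separately:

```python
import jax
import jax.numpy as jnp
from jax.extend import ffi

# ffi_call now returns a callable instead of executing immediately.
# "my_custom_op" is hypothetical and must be registered beforehand
# (e.g. via ffi.register_ffi_target).
call = ffi.ffi_call(
    "my_custom_op",
    jax.ShapeDtypeStruct((4,), jnp.float32),  # result shape and dtype
)
# The returned callable takes the runtime arguments:
# out = call(x)
```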
Yash Katariya
ca2d1584f8 Remove mesh_utils.create_device_mesh from docs
PiperOrigin-RevId: 687695419
2024-10-19 15:48:42 -07:00
jax authors
919f7c8684 Merge pull request #24345 from phu0ngng:cuda_custom_call
PiperOrigin-RevId: 687034466
2024-10-17 13:57:15 -07:00
Sergei Lebedev
de7beb91a7 [pallas:mosaic_gpu] Added layout_cast
PiperOrigin-RevId: 686917796
2024-10-17 08:08:05 -07:00
jax authors
3bdc57dd29 Merge pull request #24300 from ROCm:ci_rocm_readme
PiperOrigin-RevId: 686872994
2024-10-17 05:21:13 -07:00
George Necula
9aa79bffba [export] Fix github links in the export documentation
Reflects the repo change google/jax -> jax-ml/jax.
Also changes the error message to put the link to the documentation
in a more visible place.
2024-10-17 08:30:28 +01:00
Jake VanderPlas
e1f280c843 CI: enable additional ruff formatting checks 2024-10-16 16:09:54 -07:00
Ruturaj4
3c3b08dfd6 [ROCm] Fix README.md to update AMD JAX installation instructions 2024-10-16 17:15:32 -05:00
jax authors
089e4aa904 Merge pull request #24341 from phu0ngng:cuda_graph_ex
PiperOrigin-RevId: 686577115
2024-10-16 11:23:28 -07:00
jax authors
ead1c05ada Merge pull request #23831 from nouiz:doc_policies
PiperOrigin-RevId: 686576725
2024-10-16 11:21:41 -07:00
Phuong Nguyen
d4bbb4fd84 added cmdBuffer traits
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
2024-10-16 10:37:49 -07:00
Phuong Nguyen
82113cd047 rm CmdBuffer traits
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
2024-10-16 10:27:09 -07:00
Phuong Nguyen
f3775aa233 added cudaGraph traits + use register_ffi_target()
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
2024-10-16 10:01:20 -07:00
Sergei Lebedev
4c0d82824f [pallas:mosaic_gpu] Added a few more operations necessary to port Flash Attention
PiperOrigin-RevId: 686451398
2024-10-16 04:05:36 -07:00
Yash Katariya
66c6292e6a Make committed a public property of jax.Array.
Why?

Because users need to know whether an array is committed: JAX raises errors and makes dispatch decisions based on the committedness of a jax.Array.
But the placement of such arrays on devices is an internal implementation detail.

PiperOrigin-RevId: 686329828
2024-10-15 19:46:10 -07:00
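A minimal sketch of the newly public property:

```python
import jax
import jax.numpy as jnp

x = jnp.ones(4)                          # uncommitted: JAX may place it freely
y = jax.device_put(x, jax.devices()[0])  # committed to a specific device
print(x.committed, y.committed)          # False True
```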
Praveen Batra
3a3190fbce Fix typo in Pallas TPU matmul doc: the logical layout of the input array appears to be non-transposed, rather than transposed.
PiperOrigin-RevId: 686151692
2024-10-15 10:23:39 -07:00
Yash Katariya
a2973be051 Don't add mhlo.layout_mode = "default", since that is already the default in PJRT; omitting it reduces cruft in the IR
PiperOrigin-RevId: 684963359
2024-10-11 14:54:32 -07:00
Justin Fu
cff9e93824 [Pallas] Add runtime assert via checkify.check. This check will halt the TPU if triggered, meaning that we would need to restart the program to recover.
PiperOrigin-RevId: 684940271
2024-10-11 13:34:04 -07:00
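The hard TPU halt is the new behavior; the user-facing pattern follows the usual checkify functionalization, roughly as in this sketch (interpret mode is used here only so the example runs anywhere):

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify
from jax.experimental import pallas as pl

def kernel(x_ref, o_ref):
    # Fails the kernel if any input element is negative.
    checkify.check(jnp.all(x_ref[...] >= 0), "negative input!")
    o_ref[...] = x_ref[...] * 2

x = jnp.arange(8, dtype=jnp.float32)
call = pl.pallas_call(
    kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,
)
err, out = checkify.checkify(call)(x)
err.throw()  # no-op here, since the check passes
```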
Peter Hawkins
46f0a3eee7 Clone RandomAlgorithm into lax.py, instead of using the version from XLA.
Change in preparation for removing HLO ops from the XLA Python bindings.

In passing, also:
* improve how the documentation of FftType renders.
* remove some stale references to xla_client
* remove the standard_translate rule, which is unused.

PiperOrigin-RevId: 684892102
2024-10-11 11:03:15 -07:00
Frédéric Bastien
e9011940d8
Update docs/gradient-checkpointing.md
Co-authored-by: Matthew Johnson <mattjj@google.com>
2024-10-11 12:33:10 -04:00
jax authors
bc3df0e3f5 Merge pull request #24241 from hawkinsp:autodidax
PiperOrigin-RevId: 684811631
2024-10-11 06:08:32 -07:00
Peter Hawkins
c0efa86bdc Port autodidax to use StableHLO instead of classic HLO. 2024-10-11 08:25:05 -04:00
Sergei Lebedev
acd0e497af [pallas:mosaic_gpu] GPUBlockSpec no longer accepts swizzle
It was previously possible to pass `swizzle` both directly and via `transforms`.
This change eliminates the ambiguity, at a slight cost to ergonomics.

PiperOrigin-RevId: 684797980
2024-10-11 05:11:26 -07:00
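A sketch of the disambiguated API, assuming the `TilingTransform`/`SwizzleTransform` names from the Mosaic GPU Pallas backend:

```python
from jax.experimental.pallas import mosaic_gpu as plgpu

# `swizzle` now goes through `transforms` rather than a dedicated argument.
spec = plgpu.GPUBlockSpec(
    (128, 64),         # block shape
    lambda i: (i, 0),  # index map
    transforms=(
        plgpu.TilingTransform((64, 64)),
        plgpu.SwizzleTransform(128),
    ),
)
```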
Peter Hawkins
94abaf430e Add lax.FftType.
We had never provided a public name for the enum of FFT types; instead it was only known by a semi-private name (jax.lib.xla_client.FftType). Add a public name (jax.lax.FftType) and deprecate the private one.

We define a new FftType IntEnum rather than trying to expose the one in xla_client. The xla_client definition was useful when building classic HLO, but we no longer do that so there's no reason we need to couple our type to XLA's type.

PiperOrigin-RevId: 684447186
2024-10-10 08:07:35 -07:00
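A minimal sketch of the public name in use:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((8,), dtype=jnp.complex64)
# jax.lax.FftType replaces the semi-private jax.lib.xla_client.FftType.
y = lax.fft(x, fft_type=lax.FftType.FFT, fft_lengths=(8,))
```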
Sergei Lebedev
46e65b5982 [pallas] Added API docs for Triton and Mosaic GPU backends
I've left the TPU backend docs a stub for now. Hopefully, someone working
on Pallas TPU can fill them in later.
2024-10-10 12:27:53 +01:00
Dan Foreman-Mackey
1f0a04a4fc Add jax.make_mesh to API docs. 2024-10-09 13:55:43 -04:00
Sergei Lebedev
76d5938062 [pallas] Added MemoryRef and run_scoped to the API docs
PiperOrigin-RevId: 683349061
2024-10-07 15:35:09 -07:00
Dan Foreman-Mackey
28bbbf894f Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.

The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.

Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.

To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)

With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.

Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.

One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.

PiperOrigin-RevId: 683302687
2024-10-07 13:21:34 -07:00
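A sketch of the consolidated interface, assuming a backend that supports the requested preset:

```python
import jax.numpy as jnp
from jax import lax

a = jnp.ones((128, 128), dtype=jnp.float32)
b = jnp.ones((128, 128), dtype=jnp.float32)

# `precision` now accepts a dot algorithm directly; inputs are cast as
# needed, and the output dtype still matches the inputs unless
# `preferred_element_type` says otherwise.
c = lax.dot(a, b, precision=lax.DotAlgorithmPreset.BF16_BF16_F32)
```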
Sergei Lebedev
95631a7d92 Added jax.experimental.pallas.mosaic_gpu
I also deprecated `jax.experimental.pallas.gpu` in favor of
`jax.experimental.pallas.triton` to avoid confusion with the Mosaic GPU
backend.

PiperOrigin-RevId: 683119193
2024-10-07 04:05:08 -07:00
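The import-level change, in short:

```python
# Deprecated alias for the Triton backend:
# from jax.experimental.pallas import gpu as plgpu
# Its replacement:
from jax.experimental.pallas import triton as pltriton
# The new Mosaic GPU backend:
from jax.experimental.pallas import mosaic_gpu as plgpu
```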
George Necula
db89c245ac [host_callback] Remove most of the jax.experimental.host_callback module
These APIs have been deprecated since March 2024 and are subsumed by the new JAX external callbacks.
See https://github.com/google/jax/issues/20385 for a discussion.

PiperOrigin-RevId: 682830525
2024-10-06 01:10:34 -07:00
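For code still on host_callback, the external-callback replacement looks roughly like this sketch (jax.debug.callback covers the id_tap-style cases):

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_fn(x):
    return np.sin(x)  # runs on the host, with NumPy arrays

@jax.jit
def f(x):
    # Where host_callback.call was used, jax.pure_callback is the replacement.
    return jax.pure_callback(host_fn, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

print(f(jnp.arange(4.0)))
```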
Sergei Lebedev
41791ac756 [pallas] Removed support for the deprecated pl.BlockSpec argument order
PiperOrigin-RevId: 682036180
2024-10-03 14:39:58 -07:00
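Only the current argument order is accepted now:

```python
from jax.experimental import pallas as pl

# Removed order: pl.BlockSpec(index_map, block_shape)
# Current order: block shape first, then index map.
spec = pl.BlockSpec((128, 128), lambda i, j: (i, j))
```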
Ayaka
e79d77aa47 [Pallas] [Docs] Replace full urls with label-based cross references
This PR uses the same method to add cross references as the previous PR https://github.com/jax-ml/jax/pull/23889.

---

The content below is for future reference.

#### Useful commands

Build documentation:

```sh
sphinx-build -b html -D nb_execution_mode=off docs docs/build/html -j auto
```

Create a label in *.md:

```md
(pallas_block_specs_by_example)=
```

Create a label in *.rst:

```rst
.. _pallas_tpu_noteworthy_properties:
```

Reference a label in *.md:

```md
{ref}`pallas_block_specs_by_example`
```

Sync changes from *.md to *.ipynb:

```sh
jupytext --sync docs/pallas/tpu/distributed.md
```

PiperOrigin-RevId: 682034607
2024-10-03 14:35:51 -07:00
Jake VanderPlas
635e29a0b9 Implement jax.numpy.spacing
Somehow we've missed this numpy API up until now.
2024-10-03 10:40:39 -07:00
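A quick sketch of the new function, which matches numpy.spacing:

```python
import jax.numpy as jnp

# Distance from a value to the next representable float of the same dtype.
print(jnp.spacing(jnp.float32(1.0)))                                # ~1.1920929e-07
print(jnp.spacing(jnp.float32(1.0)) == jnp.finfo(jnp.float32).eps)  # True
```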
Ayaka
ad78147183 [Docs] Add docstring for RoundingMethod
Currently, the class only has "An enumeration." as the docstring when viewing the documentation, which is unhelpful for users. This PR adds class members, detailed descriptions and cross-references to the docstring to make it beautiful and informative.

PiperOrigin-RevId: 681866947
2024-10-03 07:23:22 -07:00
Dan Foreman-Mackey
1d27d420ac Deprecate the vectorized argument to pure_callback and ffi_call. 2024-10-02 11:33:51 -04:00
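A sketch of the replacement: the `vmap_method` string argument takes over from the boolean `vectorized` flag:

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_fn(x):
    return np.square(x)

@jax.vmap
def f(x):
    # vmap_method replaces vectorized=True/False; "expand_dims" is one of
    # the named strategies (alongside "sequential" and "broadcast_all").
    return jax.pure_callback(
        host_fn, jax.ShapeDtypeStruct(x.shape, x.dtype), x,
        vmap_method="expand_dims",
    )

print(f(jnp.arange(6.0).reshape(3, 2)))
```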
Ikko Eltociear Ashimine
a7c6935994
docs: update Custom_Operation_for_GPUs.md
implementaion -> implementation
2024-10-02 12:57:45 +09:00
Peter Hawkins
26632fd344 Replace disable_backends with enable_backends on jax_multiplatform_test.
Most users of disable_backends were actually using it to enable only a single backend, so things are simpler if we negate the sense of the option. Change disable_configs to enable_configs, with a default `None` value meaning "everything is enabled".

We change the relationship between enable_backends, disable_configs, enable_configs to be the following:
* `enable_backends` selects a set of initial test configurations to enable, based off backend only.
* `disable_configs` then prunes that set of test configurations, removing elements from the set.
* `enable_configs` then adds additional configurations to the set.

Fix code in jax/experimental/mosaic/gpu/examples not to depend on a Google-internal GPU support target.

PiperOrigin-RevId: 679563155
2024-09-27 06:15:31 -07:00
Ayaka
ab4590ce0a [Pallas TPU] Add a note in the Pallas Quickstart documentation about how to run the existing example on TPU
This fixes https://github.com/jax-ml/jax/issues/22817

This change was originally proposed by @justinjfu in the comments of the above issue.

This PR is related to https://github.com/jax-ml/jax/pull/23885.

PiperOrigin-RevId: 679487218
2024-09-27 01:33:08 -07:00
jax authors
6f7ad641d7 Merge pull request #23940 from jakevdp:jacobian-doc
PiperOrigin-RevId: 679203936
2024-09-26 10:34:25 -07:00
Peter Hawkins
7b53c2f39d Add jax.errors.JaxRuntimeError as a public alias for the XlaRuntimeError class.
Deprecate jax.lib.xla_client.XlaRuntimeError, which is not a public API.

PiperOrigin-RevId: 679163106
2024-09-26 08:39:30 -07:00
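The intended usage, with `some_computation` standing in as a hypothetical function that fails inside XLA at runtime:

```python
import jax

try:
    some_computation()  # hypothetical; any XLA runtime failure raises this
except jax.errors.JaxRuntimeError as err:
    print("runtime failure:", err)
```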
Jake VanderPlas
cf51ee7ef0 Improve documentation for jax.jacobian 2024-09-26 05:09:47 -07:00
Jacob Burnim
a1f2edc968 Fix make_remote_async_copy -> make_async_remote_copy in async doc. 2024-09-25 13:39:39 -07:00
jax authors
f126705dd0 Merge pull request #23914 from rajasekharporeddy:testbranch3
PiperOrigin-RevId: 678752363
2024-09-25 10:26:32 -07:00