2140 Commits

Author SHA1 Message Date
Peter Hawkins
94abaf430e Add lax.FftType.
We had never provided a public name for the enum of FFT types; instead it was only known by a semi-private name (jax.lib.xla_client.FftType). Add a public name (jax.lax.FftType) and deprecate the private one.

We define a new FftType IntEnum rather than trying to expose the one in xla_client. The xla_client definition was useful when building classic HLO, but we no longer do that so there's no reason we need to couple our type to XLA's type.

PiperOrigin-RevId: 684447186
2024-10-10 08:07:35 -07:00
Sergei Lebedev
46e65b5982 [pallas] Added API docs for Triton and Mosaic GPU backends
I've left the TPU backend docs a stub for now. Hopefully, someone working
on Pallas TPU can fill them in later.
2024-10-10 12:27:53 +01:00
Dan Foreman-Mackey
1f0a04a4fc Add jax.make_mesh to API docs. 2024-10-09 13:55:43 -04:00
Sergei Lebedev
76d5938062 [pallas] Added MemoryRef and run_scoped to the API docs
PiperOrigin-RevId: 683349061
2024-10-07 15:35:09 -07:00
Dan Foreman-Mackey
28bbbf894f Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used to control dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.

The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.

Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match the value that a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.

To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)

With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.

Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.

One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.

PiperOrigin-RevId: 683302687
2024-10-07 13:21:34 -07:00
Sergei Lebedev
95631a7d92 Added jax.experimental.pallas.mosaic_gpu
I also deprecated `jax.experimental.pallas.gpu` in favor of
`jax.experimental.pallas.triton` to avoid confusion with the Mosaic GPU
backend.

PiperOrigin-RevId: 683119193
2024-10-07 04:05:08 -07:00
George Necula
db89c245ac [host_callback] Remove most of the jax.experimental.host_callback module
These APIs have been deprecated since March 2024 and they are subsumed by the new JAX external callbacks.
See https://github.com/google/jax/issues/20385 for a discussion.

PiperOrigin-RevId: 682830525
2024-10-06 01:10:34 -07:00
Sergei Lebedev
41791ac756 [pallas] Removed support for the deprecated pl.BlockSpec argument order
PiperOrigin-RevId: 682036180
2024-10-03 14:39:58 -07:00
Ayaka
e79d77aa47 [Pallas] [Docs] Replace full urls with label-based cross references
This PR uses the same method to add cross references as the previous PR https://github.com/jax-ml/jax/pull/23889.

---

The content below is for future references.

#### Useful commands

Build documentation:

```sh
sphinx-build -b html -D nb_execution_mode=off docs docs/build/html -j auto
```

Create a label in *.md:

```md
(pallas_block_specs_by_example)=
```

Create a label in *.rst:

```rst
.. _pallas_tpu_noteworthy_properties:
```

Reference a label in *.md:

```md
{ref}`pallas_block_specs_by_example`
```

Sync changes from *.md to *.ipynb:

```sh
jupytext --sync docs/pallas/tpu/distributed.md
```

PiperOrigin-RevId: 682034607
2024-10-03 14:35:51 -07:00
Jake VanderPlas
635e29a0b9 Implement jax.numpy.spacing
Somehow we've missed this numpy API up until now.
2024-10-03 10:40:39 -07:00
Ayaka
ad78147183 [Docs] Add docstring for RoundingMethod
Currently, the class only has "An enumeration." as the docstring when viewing the documentation, which is unhelpful for users. This PR adds class members, detailed descriptions and cross-references to the docstring to make it beautiful and informative.

PiperOrigin-RevId: 681866947
2024-10-03 07:23:22 -07:00
Dan Foreman-Mackey
1d27d420ac Deprecate the vectorized argument to pure_callback and ffi_call. 2024-10-02 11:33:51 -04:00
Ikko Eltociear Ashimine
a7c6935994
docs: update Custom_Operation_for_GPUs.md
implementaion -> implementation
2024-10-02 12:57:45 +09:00
Peter Hawkins
26632fd344 Replace disable_backends with enable_backends on jax_multiplatform_test.
Most users of disable_backends were actually using it to enable only a single backend. So things are simpler if we negate the sense of the option to say that. Change disable_configs to enable_configs, with a default `None` value meaning "everything is enabled".

We change the relationship between enable_backends, disable_configs, enable_configs to be the following:
* `enable_backends` selects a set of initial test configurations to enable, based off backend only.
* `disable_configs` then prunes that set of test configurations, removing elements from the set.
* `enable_configs` then adds additional configurations to the set.

Fix code in jax/experimental/mosaic/gpu/examples not to depend on a Google-internal GPU support target.

PiperOrigin-RevId: 679563155
2024-09-27 06:15:31 -07:00
Ayaka
ab4590ce0a [Pallas TPU] Add a note in the Pallas Quickstart documentation about the instructions of running the existing example on TPU
This fixes https://github.com/jax-ml/jax/issues/22817

This changes is originally proposed by @justinjfu in the comments of the above issue.

This PR is related to https://github.com/jax-ml/jax/pull/23885.

PiperOrigin-RevId: 679487218
2024-09-27 01:33:08 -07:00
jax authors
6f7ad641d7 Merge pull request #23940 from jakevdp:jacobian-doc
PiperOrigin-RevId: 679203936
2024-09-26 10:34:25 -07:00
Peter Hawkins
7b53c2f39d Add jax.errors.JaxRuntimeError as a public alias for the XlaRuntimeError class.
Deprecate jax.lib.xla_client.XlaRuntimeError, which is not a public API.

PiperOrigin-RevId: 679163106
2024-09-26 08:39:30 -07:00
Jake VanderPlas
cf51ee7ef0 Improve documentation for jax.jacobian 2024-09-26 05:09:47 -07:00
Jacob Burnim
a1f2edc968 Fix make_remote_async_copy -> make_async_remote_copy in async doc. 2024-09-25 13:39:39 -07:00
jax authors
f126705dd0 Merge pull request #23914 from rajasekharporeddy:testbranch3
PiperOrigin-RevId: 678752363
2024-09-25 10:26:32 -07:00
rajasekharporeddy
13774d1382 Fix Typos 2024-09-25 21:26:05 +05:30
Dan Foreman-Mackey
bc1e1a0220 Add support for setting a dot product "algorithm" for lax.dot_general.
The StableHLO spec has a new "algorithm" parameter that allows specifying the algorithm that is used to execute a matrix multiplication, and it can tune the trade-off between performance and computational cost. Historically, in JAX, the precision and preferred_element_type parameters have been used to expose some level of control, but their behavior is platform dependent and not sufficiently flexible for performance use cases. This change adds a new "algorithm" parameter to dot_general to add support for the new explicit API.

This parameter can be a member of the `SupportedDotAlgorithm` `Enum` to use an algorithm that is known to be supported on at least some hardware. Otherwise, it can be specified using the `DotAlgorithm` data structure which exposes the full generality of the StableHLO spec.

Transposition is supported using the `transpose_algorithm` argument.

PiperOrigin-RevId: 678672686
2024-09-25 06:17:09 -07:00
Peter Hawkins
70f91db853 Set PYTHONWARNINGS=error in bazel tests.
The goal of this change is to catch PRs that introduce new warnings sooner.

To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable.

Add code to suppress some new warnings uncovered in CI.

PiperOrigin-RevId: 678352286
2024-09-24 12:30:11 -07:00
jax authors
6e116491c1 Add --use_cuda_nvcc flag to enable or disable compilation of CUDA code using NVCC.
If `--use_cuda_nvcc` flag is set the NVCC compiler driver will be used to build the CUDA code (default behavior). Otherwise, if the flag `--nouse_cuda_nvcc` is set, only the clang compiler will be used to build the CUDA code (effectively disabling NVCC).

Mark `--use_clang` flag as deprecated.

Refactor `.bazelrc` configs to match the new flag and to cleanup all previous confusing names.

PiperOrigin-RevId: 678332548
2024-09-24 11:37:00 -07:00
Yash Katariya
a99ea73336 Use jax.make_array_from_process_local_data API in distributed data loading doc
PiperOrigin-RevId: 677973689
2024-09-23 16:03:34 -07:00
Dongseong Hwang
e4091a6752 Fix another errata in block-sparse kernel tutorial.
PiperOrigin-RevId: 677952796
2024-09-23 15:04:29 -07:00
Dongseong Hwang
91f16419bb Fix errata in block-sparse kernel tutorial.
Correct M//blk_M to N//blk_N. It was ok because both values happen to be same.
In addition, grid order is (num_blocks, j) as 'num_blocks' replaces 'i'.

PiperOrigin-RevId: 677817478
2024-09-23 09:07:28 -07:00
Frederic Bastien
a159c0f417 Document jax.checkpoint policies. 2024-09-22 16:05:20 -04:00
8bitmp3
0cf040c9a1 Add/update JAX Advanced Tutorials docs, ToC structure 2024-09-20 23:06:54 +00:00
jax authors
886aa944fa Merge pull request #23707 from jakevdp:stop-gradient-doc
PiperOrigin-RevId: 676876785
2024-09-20 09:48:08 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
jax authors
5f044a67d8 Merge pull request #23674 from justinjfu:pallas_prefetch_docs
PiperOrigin-RevId: 676525366
2024-09-19 12:49:28 -07:00
Justin Fu
4bce4f6452 [Pallas] Add block-sparse kernel tutorial 2024-09-19 12:23:03 -07:00
Dan Foreman-Mackey
73c38cb700 Add a note to the developer docs making it clear that clang is the only
toolchain that is actively supported for source compilation.

As discussed in https://github.com/google/jax/issues/23687
2024-09-18 10:01:13 -04:00
Sergei Lebedev
e90336947a Pulled scratch_shapes into GridSpec
It is supported by Mosaic TPU and Mosaic GPU and unsupported by Triton.

PiperOrigin-RevId: 675950199
2024-09-18 05:26:21 -07:00
Sergei Lebedev
b904599b98 pl.debug_print no longer restricts values to be scalars
This allows printing arrays on Triton and soon on Mosaic GPU.

PiperOrigin-RevId: 675935666
2024-09-18 04:24:09 -07:00
Jake VanderPlas
3c37d4f20e Improve documentation for jax.lax.stop_gradient 2024-09-17 15:58:14 -07:00
Sharad Vikram
9d3762bd47 [Pallas] Add design note for async ops on TPU 2024-09-17 12:45:29 -07:00
Yash Katariya
8ab66c8103 Fix the TPU and GPU nightly install instructions.
PiperOrigin-RevId: 675233702
2024-09-16 11:46:58 -07:00
jax authors
dfa4e2413c Merge pull request #23643 from TomAugspurger:fix/doc-typos
PiperOrigin-RevId: 675133675
2024-09-16 07:07:26 -07:00
enerrio
b8d135aa05 fix small typos in docs 2024-09-14 13:53:19 -07:00
Tom Augspurger
fcc8c3759d Fixed func ref in shared-computation 2024-09-14 15:22:12 -05:00
George Necula
67980d6af4 [export] Improve the forward compatibility documentation
Update the documentation to use the `LoweringRuleContext.is_forward_compat`
helper function.
2024-09-13 08:38:49 +03:00
Yash Katariya
de9b98e0a8 Delete jax.xla_computation since it's been 3 months since it was deprecated.
PiperOrigin-RevId: 673938336
2024-09-12 11:47:38 -07:00
jax authors
7c8508e593 Add link to XLA documentation for building JAX with CUDA from sources.
PiperOrigin-RevId: 673510767
2024-09-11 13:20:46 -07:00
jax authors
02ab741155 Merge pull request #23478 from dfm:ffi-release-notes
PiperOrigin-RevId: 673065825
2024-09-10 12:39:49 -07:00
Peter Hawkins
b975592478 Change nightly install commands to include all packages.
pip doesn't update transitive dependencies, and we probably want the latest versions of everything when installing a nightly.
2024-09-09 14:38:41 -04:00
jax authors
c1bac25a66 Merge pull request #23424 from jakevdp:dep-doc
PiperOrigin-RevId: 672580398
2024-09-09 10:00:51 -07:00
jax authors
b0fc2759b1 Merge pull request #23512 from jakevdp:permute-dims-doc
PiperOrigin-RevId: 672574120
2024-09-09 09:42:52 -07:00
Jake VanderPlas
ab29fee763 Add array_api intersphinx & document jnp.permute_dims 2024-09-09 08:23:38 -07:00