2139 Commits

Author SHA1 Message Date
Justin Fu
dcd186f552 [Pallas] Add pallas distributed computation tutorial 2024-08-08 17:34:35 -07:00
Yash Katariya
7f8a4c84d3 Remove PositionalSharding from distributed array doc 2024-08-07 21:25:24 -07:00
George Necula
64eb8e9639 [pallas] Add a warning message about experimental and incomplete status 2024-08-07 08:38:56 +03:00
Jake VanderPlas
4f8c5a335d CI: pin sphinx to avoid build errors on 8.0 2024-08-06 09:16:41 -07:00
jax authors
44a8c98912 Merge pull request #22141 from dfm:update-cuda-call-example-to-ffi-call
PiperOrigin-RevId: 659542133
2024-08-05 07:09:12 -07:00
George Necula
9b35b760ce [pallas] Enable check for GPU lowering that tensor sizes are power of 2
Triton has a restriction that all operations have arguments and results
that are tensor whose size is a power of 2. Added a lowering check
for this. Without this, when we violate the condition we get an
unfriendly crash.

PiperOrigin-RevId: 659483450
2024-08-05 02:34:21 -07:00
Dan Foreman-Mackey
5474e0e081 Update CUDA custom call example code to use ffi_call.
Following up on #21925, we can update the example code in
`docs/cuda_custom_call` to use `ffi_call` instead of manually
registering `core.Primitive`s. This removes quite a bit of boilerplate
and doesn't require direct use of MLIR.
2024-08-02 10:15:10 -04:00
jax authors
aa9e1e42a1 Merge pull request #22095 from dfm:ffi-call-tutorial
PiperOrigin-RevId: 658523039
2024-08-01 13:43:34 -07:00
Dan Foreman-Mackey
0b4800a193 Add ffi_call tutorial
Building on #21925, this tutorial demonstrates the use of the FFI using
`ffi_call` with a simple example. I don't think this should cover all of
the most advanced use cases, but it should be sufficient for the most
common examples. I think it would be useful to eventually replace the
existing CUDA tutorial, but I'm not sure that it'll get there in the
first draft.

As an added benefit, this also runs a simple test (akin to
`docs/cuda_custom_call`) which actually executes using a tool chain that
open source users would use in practice.
2024-08-01 15:36:32 -04:00
Jake VanderPlas
14fa06298e [array api] Finalize array API in jax.numpy & deprecate jax.experimental.array_api 2024-08-01 11:19:17 -07:00
George Necula
43163ff2e3 [pallas] Add error message for block_shapes of rank less than 1.
PiperOrigin-RevId: 658424421
2024-08-01 09:15:07 -07:00
jax authors
d696813b1f Merge pull request #22746 from gnecula:pallas_consts
PiperOrigin-RevId: 658050734
2024-07-31 10:13:34 -07:00
jax authors
a207fe9b77 Export KeyPath and related types to jax.tree_util
These types lie on the APIs in `jax.tree_util`, so it makes sense to export them.

PiperOrigin-RevId: 657987755
2024-07-31 06:41:33 -07:00
George Necula
987bf33e85 [pallas] Disallow capturing of consts by kernel functions.
Previously this was allowed, but until recently (#22550) it was
not working correctly in many cases. Now we disallow const
capturing because it can lead to surprises. Instead, the
kernel function must receive all the arrays it needs as explicit
inputs, with proper block specs.
2024-07-31 09:06:29 +02:00
George Necula
6d53aaf7d0 [pallas] Improve the error localization
* Add the source location information for the index map function to
    `BlockMapping`.
  * Removed the `compute_index` wrapper around the index_map, so that
    we can get the location information for the index_map, not the wrapper.
  * Added source location to the errors related to index map functions.
  * Added an error if the index map returns something other than integer
    scalars.
  * Construct BlockSpec origins for arguments using JAX helper functions
    to get argument names
  * Removed redundant API error tests from tpu_pallas_test.py
2024-07-30 14:11:57 +02:00
Yash Katariya
30037547d7 Bump minimum jaxlib version to 0.4.31. The corresponding xla_extension_version is 279 and mlir_api_version is 57
PiperOrigin-RevId: 657400413
2024-07-29 18:44:31 -07:00
Peter Hawkins
d1c0d993fc Bump the minimum CUDNN version to v9.1.
This actually was already the minimum version since we build with that version, but we needed to tighten the constraints.

Also in passing, drop mentions of CUDA builds from the Windows build instructions. jaxlib hasn't built with CUDA enabled on Windows for a very long time, so it's probably best we just don't mention it.

PiperOrigin-RevId: 657225917
2024-07-29 09:28:47 -07:00
George Necula
70a11acbb1 [pallas] More simplification of grid mapping and calling convention
In previous PR #22552 I have expanded `GridMapping` to encode more
parts of the calling convention. Here we use that new functionality
and clean up some code.

I have removed the internal methods from `BlockSpec` and `GridSpec` because
these classes are part of the API.

I added entries to pallas/CHANGELOG.
2024-07-29 15:53:47 +02:00
George Necula
68972de021 [pallas] Add lowering errors for block shapes that are not supported.
Previously these errors came from Mosaic with less useful stack traces, and in the case of GPU we were getting a crash instead of an exception.

PiperOrigin-RevId: 657184114
2024-07-29 06:49:27 -07:00
jax authors
8ed94bcfb6 [shard_map docs]: Fix doc typos
PiperOrigin-RevId: 656265100
2024-07-25 23:29:55 -07:00
Yash Katariya
2eb1888c98 Make the vmap(jit) or vmap(wsc) with a concrete layout error more informative
PiperOrigin-RevId: 656176702
2024-07-25 18:32:37 -07:00
George Necula
e6066e52a2 [pallas] Stop exporting jax.experimental.pallas.pallas
This was giving access to too many internal APIs.

PiperOrigin-RevId: 655887765
2024-07-25 03:03:58 -07:00
George Necula
4063373b22 Reverts 0d058ce86f04a44a51abba1261768fb46edf69d9
PiperOrigin-RevId: 655871052
2024-07-25 01:50:36 -07:00
Yash Katariya
0d5dae09ff Delete xmap and the jax.experimental.maps module. It's been 5 months since its deprecation (more than the standard 3 months deprecation period).
PiperOrigin-RevId: 655614395
2024-07-24 10:24:09 -07:00
George Necula
0d058ce86f Reverts 0e17d26b6d81a6b281f55bdd81a6f0ab45efeafe
PiperOrigin-RevId: 655552768
2024-07-24 07:09:16 -07:00
George Necula
c5871331ba [pallas] Simplify handling of BlockMapping and GridMapping
`BlockSpec`, `GridSpec` and `PrefetchScalarGridSpec` are now simple
dataclasses that just store the parameters passed
from the API. They are then canonicalized and coverted
to `BlockMapping` and `GridMapping`, which contains fewer
optional metadata. In particular, `BlockMapping` is never
`None`. This consolidates the code to preprocess the
block and grid parameters, and simplifies the code downstream.

`grid` now defaults to `()` instead of `None`.

Added more fields to `BlockMapping` (`block_aval`,
`array_shape_dtype`, and `source`). The `source` field
is used in error messages. The `array_shape_dtype` makes
it unnecessary to process BlockMappings zipped with
`in_shapes`. With these fields, we can now add
a `check_invariants` method that is called during testing
or when `config.enable_checks` is true.

Added more fields and a `check_invariants` to `GridMapping`, since it is
such an important data structure.
The new fields are: `index_map_avals`, `index_map_tree` (to encode
the calling convention for the index map functions),
`num_inputs`, `num_outputs`. The latter make it possible to
recover the `in_shapes` and `out_shapes` from the GridMapping.
Previously there was some redundancy of information between
`in_shapes` and `out_shapes`.

Now we do not need the `in_shapes` and `out_shapes` parameters to
`pallas_call_p`, since it already has `grid_mapping`.

Moved some of the logic for handling scalar prefetch and
scratch shapes from `PrefetchScalarGridSpec.get_grid_mapping` to
`GridSpec.get_grid_mapping`, and thus removed code duplication.

Removed some dead code for implementing the interpret mode.

Previous handling of hoisted consts did not account for them in
`in_shapes`. Now, this is fixed since we do not keep track of
`in_shapes` separately.

Renamed `GridMapping.mapped_dims` to `GridMapping.vmapped_dims` to
avoid confusion with the use of mapped in block shapes.

Added test for the calling convention, including dynamic grid dimensions.

There is more work to be done: with the new information in
`GridMapping` it should be possible to clean the code throughout
that extract various parts of the inputs and outputs. This
should be a bunch of local changes, which I will do separately
once I merge this large global change.
2024-07-24 14:48:08 +03:00
jax authors
696e73042b Merge pull request #22558 from superbobry:pallas
PiperOrigin-RevId: 655138162
2024-07-23 06:12:46 -07:00
jax authors
ce5f9a6da9 Merge pull request #22530 from superbobry:maint
PiperOrigin-RevId: 654881710
2024-07-22 13:46:58 -07:00
Dan Foreman-Mackey
705eed3388 Fixing dtype canonicalization in sharp edges tutorial.
As reported in https://github.com/google/jax/issues/22493, the sharp
edges tutorial doesn't seem to actually enable x64 when it says it does.

Fixes https://github.com/google/jax/issues/22493
2024-07-22 15:02:02 -04:00
Sergei Lebedev
7157839853 Fixed pl.BlockSpec argument ordering in the Pallas TPU matmul tutorial 2024-07-22 12:28:40 +01:00
Sergei Lebedev
4fa93cff35 Documented a few more Pallas APIs and added them to the API docs 2024-07-21 22:32:51 +01:00
jax authors
e9c40467d7 Merge pull request #22526 from sharadmv:pallas-matmul-docs
PiperOrigin-RevId: 654112294
2024-07-19 13:44:48 -07:00
Dan Foreman-Mackey
b308c64936 Export jaxlib.xla_client.register_custom_call_target as jax.extend.ffi.register_ffi_target.
This means that users of the FFI interface won't need to directly
interact with `jaxlib.xla_client` at all.

I've expanded the doctring a little and changed one default: the default
`api_version` is `1` instead of `0` to be consistent with the new name.
2024-07-19 08:12:25 -04:00
Sharad Vikram
b8c4dde429 Copy improvements
Co-authored-by: Jacob Austin<jaaustin@google.com>
2024-07-18 22:59:00 -07:00
Sharad Vikram
c504fa6a4a Minor matmul docs improvement 2024-07-18 15:21:18 -07:00
George Necula
6c5583d6aa [pallas] Document and test valid block shapes.
I have only added tests and documentation, will improve error
reporting separately.

For TPU we get a mix of errors from either the Pallas lowering or
from Mosaic. I plan to add lowering exception for all unsupported
cases, so that we have a better Python stack trace available.

For GPU, we get a RET_CHECK instead of a Python exception,
so I had to add skipTest. Will fix the error message separately.

In order to be able to put the test in pallas_test::PallasCallTest, I
moved the skipTest for TPU from the setUp to the individual tests that
need this.

PiperOrigin-RevId: 653195289
2024-07-17 05:17:20 -07:00
jax authors
85c30d2a86 Merge pull request #20021 from sharadmv:pallas-matmul-docs
PiperOrigin-RevId: 653070524
2024-07-16 20:16:04 -07:00
Sharad Vikram
10e09af7a0 Address changes 2024-07-16 19:25:32 -07:00
Sharad Vikram
ff62d5e229 Address changes 2024-07-16 19:24:56 -07:00
Justin Fu
6ba889c01c [Pallas] Add support for checkify in TPU execution mode.
PiperOrigin-RevId: 653045818
2024-07-16 18:13:02 -07:00
Sharad Vikram
39ec5dacb4 [Pallas TPU] Add matrix multiplication tutorial 2024-07-16 18:12:19 -07:00
jax authors
5ddec63a47 Merge pull request #22441 from gnecula:test_clean_hypothesis
PiperOrigin-RevId: 652919414
2024-07-16 11:32:46 -07:00
kaixih
0d387e0839 Update jax doc sdpa 2024-07-15 17:30:54 +00:00
jax authors
2b29a94255 Merge pull request #22375 from jakevdp:mypy-docs
PiperOrigin-RevId: 652511749
2024-07-15 09:52:07 -07:00
George Necula
d3454f374e Add some hypothesis testing utilities and developer documentation.
Add a helper function for setting up hypothesis testing,
with support for selecting an interactive hypothesis profile
that speeds up interactive development.
2024-07-15 17:05:32 +02:00
George Necula
be8e83adc1 [docs] Fix docs building error
The checkify APIs were mentioned in the jax.experimental.rst and also
in jax.experimental.checkify.rst.
2024-07-15 15:42:33 +01:00
jax authors
f60643801d Merge pull request #22370 from gnecula:pallas_unblocked
PiperOrigin-RevId: 651770174
2024-07-12 07:41:38 -07:00
George Necula
7c059d4630 [pallas] Document the indexing_mode=Unblocked()
In the process discovered that the padding in the interpreter
mode was with 0s. I changed it to NaN/minint to match the
padding for the blocked mode.
2024-07-12 12:39:10 +03:00
George Necula
9cd94019b4 [pallas] Added a CHANGELOG for Pallas
The CHANGELOG is populated with the changes since June 10th, when
JAX 0.4.29 was released.
2024-07-12 00:05:31 +03:00
George Necula
ea548e7c86 [pallas] Add more documentation and tests for BlockSpec.
This PR deals with the default values for the parameters
of the `BlockSpec` constructor, and the mapped block dimensions.

Fix a bug where previously a missing block_shape while the
index_map was present was resulting in a crash.
2024-07-10 19:16:53 +03:00