1818 Commits

Author SHA1 Message Date
George Necula
e6066e52a2 [pallas] Stop exporting jax.experimental.pallas.pallas
This was giving access to too many internal APIs.

PiperOrigin-RevId: 655887765
2024-07-25 03:03:58 -07:00
George Necula
4063373b22 Reverts 0d058ce86f04a44a51abba1261768fb46edf69d9
PiperOrigin-RevId: 655871052
2024-07-25 01:50:36 -07:00
Yash Katariya
0d5dae09ff Delete xmap and the jax.experimental.maps module. It's been 5 months since its deprecation (more than the standard 3 months deprecation period).
PiperOrigin-RevId: 655614395
2024-07-24 10:24:09 -07:00
George Necula
0d058ce86f Reverts 0e17d26b6d81a6b281f55bdd81a6f0ab45efeafe
PiperOrigin-RevId: 655552768
2024-07-24 07:09:16 -07:00
George Necula
c5871331ba [pallas] Simplify handling of BlockMapping and GridMapping
`BlockSpec`, `GridSpec` and `PrefetchScalarGridSpec` are now simple
dataclasses that just store the parameters passed
from the API. They are then canonicalized and converted
to `BlockMapping` and `GridMapping`, which contain fewer
optional metadata. In particular, `BlockMapping` is never
`None`. This consolidates the code to preprocess the
block and grid parameters, and simplifies the code downstream.

`grid` now defaults to `()` instead of `None`.

Added more fields to `BlockMapping` (`block_aval`,
`array_shape_dtype`, and `source`). The `source` field
is used in error messages. The `array_shape_dtype` makes
it unnecessary to process BlockMappings zipped with
`in_shapes`. With these fields, we can now add
a `check_invariants` method that is called during testing
or when `config.enable_checks` is true.

Added more fields and a `check_invariants` to `GridMapping`, since it is
such an important data structure.
The new fields are: `index_map_avals`, `index_map_tree` (to encode
the calling convention for the index map functions),
`num_inputs`, `num_outputs`. The latter make it possible to
recover the `in_shapes` and `out_shapes` from the GridMapping.
Previously there was some redundancy of information between
`in_shapes` and `out_shapes`.

Now we do not need the `in_shapes` and `out_shapes` parameters to
`pallas_call_p`, since it already has `grid_mapping`.

Moved some of the logic for handling scalar prefetch and
scratch shapes from `PrefetchScalarGridSpec.get_grid_mapping` to
`GridSpec.get_grid_mapping`, and thus removed code duplication.

Removed some dead code for implementing the interpret mode.

Previous handling of hoisted consts did not account for them in
`in_shapes`. Now, this is fixed since we do not keep track of
`in_shapes` separately.

Renamed `GridMapping.mapped_dims` to `GridMapping.vmapped_dims` to
avoid confusion with the use of mapped in block shapes.

Added test for the calling convention, including dynamic grid dimensions.

There is more work to be done: with the new information in
`GridMapping` it should be possible to clean up the code throughout
that extracts various parts of the inputs and outputs. This
should be a bunch of local changes, which I will do separately
once I merge this large global change.
2024-07-24 14:48:08 +03:00
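
The split described in the commit above (a plain spec that stores what the user passed, then a canonical mapping with no optional fields and a `check_invariants` method) can be illustrated with a small plain-Python sketch. This is not the Pallas implementation: `SimpleBlockSpec`, `CanonicalBlockMapping`, and `canonicalize` are hypothetical names, and the defaults chosen here (whole-array block, index map returning zeros) are assumptions made for illustration.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Tuple


@dataclass(frozen=True)
class SimpleBlockSpec:
    """Hypothetical user-facing spec: stores the arguments exactly as passed."""
    block_shape: Optional[Tuple[int, ...]] = None
    index_map: Optional[Callable[..., Tuple[int, ...]]] = None


@dataclass(frozen=True)
class CanonicalBlockMapping:
    """Hypothetical canonical form: every field is filled in, none is optional."""
    block_shape: Tuple[int, ...]
    index_map: Callable[..., Tuple[int, ...]]
    array_shape: Tuple[int, ...]
    source: str  # where the spec came from, used only in error messages

    def check_invariants(self) -> None:
        # The block must have the same rank as the array it slices.
        if len(self.block_shape) != len(self.array_shape):
            raise ValueError(
                f"{self.source}: block_shape {self.block_shape} and array "
                f"shape {self.array_shape} have different ranks")
        # Every block dimension must be a positive size.
        if any(b <= 0 for b in self.block_shape):
            raise ValueError(
                f"{self.source}: block_shape {self.block_shape} must be positive")


def canonicalize(spec: SimpleBlockSpec,
                 array_shape: Tuple[int, ...],
                 source: str) -> CanonicalBlockMapping:
    """Fill in defaults (assumed here) and validate the result once, up front."""
    rank = len(array_shape)
    block_shape = spec.block_shape if spec.block_shape is not None else array_shape
    if spec.index_map is not None:
        index_map = spec.index_map
    else:
        # Assumed default: always select block 0 along every dimension.
        index_map = lambda *grid_indices: (0,) * rank
    mapping = CanonicalBlockMapping(block_shape, index_map, array_shape, source)
    mapping.check_invariants()
    return mapping
```

Downstream code in such a design only ever sees the canonical form, so it does not need to re-check for missing fields, and the `source` string lets an error point back at the offending argument.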
jax authors
696e73042b Merge pull request #22558 from superbobry:pallas
PiperOrigin-RevId: 655138162
2024-07-23 06:12:46 -07:00
jax authors
ce5f9a6da9 Merge pull request #22530 from superbobry:maint
PiperOrigin-RevId: 654881710
2024-07-22 13:46:58 -07:00
Dan Foreman-Mackey
705eed3388 Fixing dtype canonicalization in sharp edges tutorial.
As reported in https://github.com/google/jax/issues/22493, the sharp
edges tutorial doesn't seem to actually enable x64 when it says it does.

Fixes https://github.com/google/jax/issues/22493
2024-07-22 15:02:02 -04:00
Sergei Lebedev
7157839853 Fixed pl.BlockSpec argument ordering in the Pallas TPU matmul tutorial
2024-07-22 12:28:40 +01:00
Sergei Lebedev
4fa93cff35 Documented a few more Pallas APIs and added them to the API docs
2024-07-21 22:32:51 +01:00
jax authors
e9c40467d7 Merge pull request #22526 from sharadmv:pallas-matmul-docs
PiperOrigin-RevId: 654112294
2024-07-19 13:44:48 -07:00
Dan Foreman-Mackey
b308c64936 Export jaxlib.xla_client.register_custom_call_target as jax.extend.ffi.register_ffi_target.
This means that users of the FFI interface won't need to directly
interact with `jaxlib.xla_client` at all.

I've expanded the docstring a little and changed one default: the default
`api_version` is `1` instead of `0` to be consistent with the new name.
2024-07-19 08:12:25 -04:00
Sharad Vikram
b8c4dde429 Copy improvements
Co-authored-by: Jacob Austin <jaaustin@google.com>
2024-07-18 22:59:00 -07:00
Sharad Vikram
c504fa6a4a Minor matmul docs improvement
2024-07-18 15:21:18 -07:00
George Necula
6c5583d6aa [pallas] Document and test valid block shapes.
I have only added tests and documentation, will improve error
reporting separately.

For TPU we get a mix of errors from either the Pallas lowering or
from Mosaic. I plan to add lowering exceptions for all unsupported
cases, so that we have a better Python stack trace available.

For GPU, we get a RET_CHECK instead of a Python exception,
so I had to add skipTest. Will fix the error message separately.

In order to be able to put the test in pallas_test::PallasCallTest, I
moved the skipTest for TPU from the setUp to the individual tests that
need this.

PiperOrigin-RevId: 653195289
2024-07-17 05:17:20 -07:00
jax authors
85c30d2a86 Merge pull request #20021 from sharadmv:pallas-matmul-docs
PiperOrigin-RevId: 653070524
2024-07-16 20:16:04 -07:00
Sharad Vikram
10e09af7a0 Address changes
2024-07-16 19:25:32 -07:00
Sharad Vikram
ff62d5e229 Address changes
2024-07-16 19:24:56 -07:00
Justin Fu
6ba889c01c [Pallas] Add support for checkify in TPU execution mode.
PiperOrigin-RevId: 653045818
2024-07-16 18:13:02 -07:00
Sharad Vikram
39ec5dacb4 [Pallas TPU] Add matrix multiplication tutorial
2024-07-16 18:12:19 -07:00
jax authors
5ddec63a47 Merge pull request #22441 from gnecula:test_clean_hypothesis
PiperOrigin-RevId: 652919414
2024-07-16 11:32:46 -07:00
kaixih
0d387e0839 Update jax doc sdpa
2024-07-15 17:30:54 +00:00
jax authors
2b29a94255 Merge pull request #22375 from jakevdp:mypy-docs
PiperOrigin-RevId: 652511749
2024-07-15 09:52:07 -07:00
George Necula
d3454f374e Add some hypothesis testing utilities and developer documentation.
Add a helper function for setting up hypothesis testing,
with support for selecting an interactive hypothesis profile
that speeds up interactive development.
2024-07-15 17:05:32 +02:00
George Necula
be8e83adc1 [docs] Fix docs building error
The checkify APIs were mentioned in the jax.experimental.rst and also
in jax.experimental.checkify.rst.
2024-07-15 15:42:33 +01:00
jax authors
f60643801d Merge pull request #22370 from gnecula:pallas_unblocked
PiperOrigin-RevId: 651770174
2024-07-12 07:41:38 -07:00
George Necula
7c059d4630 [pallas] Document the indexing_mode=Unblocked()
In the process, I discovered that the padding in the interpreter
mode was with 0s. I changed it to NaN/minint to match the
padding for the blocked mode.
2024-07-12 12:39:10 +03:00
George Necula
9cd94019b4 [pallas] Added a CHANGELOG for Pallas
The CHANGELOG is populated with the changes since June 10th, when
JAX 0.4.29 was released.
2024-07-12 00:05:31 +03:00
George Necula
ea548e7c86 [pallas] Add more documentation and tests for BlockSpec.
This PR deals with the default values for the parameters
of the `BlockSpec` constructor, and the mapped block dimensions.

Fix a bug where a missing block_shape combined with a present
index_map resulted in a crash.
2024-07-10 19:16:53 +03:00
Jake VanderPlas
f3b7aea283 DOC: improve mypy/pre-commit instructions
2024-07-10 09:06:03 -07:00
Tom Ward
ebfbd8ac0c Fix cuda custom call example to build with updated XLA FFI API.
PiperOrigin-RevId: 650977379
2024-07-10 05:29:58 -07:00
Vadym Matsishevskyi
fb3607c1d5 Use inclusion list configuration for local wheels.
Also some documentation improvements/clarifications.

This allows it to not remove unused local wheels from the dist directory to avoid conflicts.

PiperOrigin-RevId: 650697758
2024-07-09 11:25:31 -07:00
bion howard
1ace88bfba Update quickstart.md
fix minor grammar typo
2024-07-09 12:48:51 -04:00
George Necula
f02d32c680 [pallas] Fix the interpreter for block_shape not dividing the overall shape
Before this change, the interpreter was failing with an MLIR
verification error because the body of the while loop returned
a padded output array.

This change allows us to expand the documentation of block specs
with the case for when block_shape does not divide the overall shape.
2024-07-09 16:10:22 +03:00
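
The case mentioned in the commit above, where `block_shape` does not divide the overall shape, comes down to ceiling division: the last block along a dimension is only partly covered by real data. The plain-Python sketch below shows just that arithmetic. It is not Pallas code; `cdiv`, `block_layout`, and `pad_to_blocks` are hypothetical helpers, and the NaN fill value is an assumption chosen for floating-point data.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division for positive integers."""
    return (a + b - 1) // b


def block_layout(array_shape, block_shape):
    """Return (blocks per dimension, shape after padding up to whole blocks)."""
    num_blocks = tuple(cdiv(d, b) for d, b in zip(array_shape, block_shape))
    padded_shape = tuple(n * b for n, b in zip(num_blocks, block_shape))
    return num_blocks, padded_shape


def pad_to_blocks(values, block, fill=float("nan")):
    """Pad a 1-D list so its length is a whole number of blocks."""
    padded_len = cdiv(len(values), block) * block
    return list(values) + [fill] * (padded_len - len(values))


# A (10, 256) array cut into (4, 128) blocks needs 3 x 2 blocks;
# the third block along the first dimension covers rows 8..11, of which
# only rows 8 and 9 hold real data.
print(block_layout((10, 256), (4, 128)))  # ((3, 2), (12, 256))
```

Reading the padded positions of a partial block yields the fill value, and anything written to them has to be discarded so that the result keeps the original, unpadded shape.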
jax authors
0d4e0ecf65 Merge pull request #22271 from ayaka14732:lru-cache-6
PiperOrigin-RevId: 650203793
2024-07-08 04:39:58 -07:00
George Necula
3df602882c [shape_poly] Small improvement in the documentation
Added an example for equality constraints.
2024-07-06 08:10:55 +03:00
Ayaka
db32021182 Update persistent compilation cache doc
2024-07-05 19:43:04 +08:00
jax authors
1e141577e3 Merge pull request #21819 from keshavb96:compilation_cache_doc
PiperOrigin-RevId: 649350829
2024-07-04 02:59:53 -07:00
jax authors
db13e6fc0e Merge pull request #22119 from dfm:cond-linear
PiperOrigin-RevId: 648535400
2024-07-01 17:36:59 -07:00
jax authors
b669ab7bb1 Merge pull request #21925 from dfm:ffi-call
PiperOrigin-RevId: 648532673
2024-07-01 17:24:10 -07:00
Sergei Lebedev
a2a5068e5e Changed `pl.BlockSpec` to accept block_shape before index_map
So, instead of

    pl.BlockSpec(lambda i, j: ..., (42, 24))

``pl.BlockSpec`` now expects

    pl.BlockSpec((42, 24), lambda i, j: ...)

I will update Pallas tests in a follow up.

PiperOrigin-RevId: 648486321
2024-07-01 14:26:08 -07:00
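
When migrating call sites to the new argument order described in the commit above, the two orderings can be told apart by which argument is callable. The helper below is a plain-Python sketch of that check; `normalize_block_spec_args` is a hypothetical name and is not part of Pallas.

```python
def normalize_block_spec_args(first, second):
    """Return (block_shape, index_map) for either argument ordering.

    Hypothetical migration helper: the old ordering passed the index map
    (a callable) first, the new ordering passes the block shape first.
    """
    if callable(first) and not callable(second):
        # Old ordering: (index_map, block_shape). Swap into the new order.
        return second, first
    # Already in the new ordering: (block_shape, index_map).
    return first, second


index_map = lambda i, j: (i, j)

# Both spellings normalize to the same (block_shape, index_map) pair.
old_style = normalize_block_spec_args(index_map, (42, 24))
new_style = normalize_block_spec_args((42, 24), index_map)
```

The resulting pair can then be passed positionally in the new order, block shape first.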
Dan Foreman-Mackey
6becf716f3 Remove linear parameter from lax.cond_p.
As far as I can tell, it seems like the `linear` parameter in the
`lax.cond_p` primitive only exists for historical reasons. It could be
used for type checking in `_cond_transpose`, but that was removed
because of #14026. With this in mind, we could stop tracking this
parameter as implemented in this PR, unless we expect that we'd want to
re-introduce the type checking in the future.
2024-07-01 10:25:42 -04:00
Dan Foreman-Mackey
e9b087d3a8 Add ffi_call function with a similar signature to pure_callback.
This could be useful for supporting the most common use cases for FFI custom
calls. It has several benefits over using the `Primitive` based approach, but
the biggest one (in my opinion) is that it doesn't require interacting with
`mlir` at all. It does have the limitation that transforms would need to be
registered using interfaces like `custom_vjp`, but many users of custom calls
already do that.

~~The easiest to-do item (I think) is to implement batching using a
`vectorized` parameter like `pure_callback`, but we could also think about more
sophisticated vmapping interfaces in the future.~~ Done.

The more difficult to-do is to think about how to support sharding, and we
might actually want to expose an interface similar to the one from
`custom_partitioning`. I have less experience with this part so I'll have to
think some more about it, and feedback would be appreciated!
2024-07-01 09:40:31 -04:00
Sergei Lebedev
e80632e6fd Revived the workaround for not-expanding type aliases
The version here only works for modules with
``from __future__ import annotations``, but we can safely add that import
to all modules now, since the minimal Python version JAX supports is 3.10.

The workaround was previously removed in #3485.
2024-07-01 14:31:53 +01:00
jax authors
5fac179f2f Merge pull request #22134 from gnecula:pallas_doc
PiperOrigin-RevId: 648147118
2024-06-30 09:15:16 -07:00
George Necula
bfdf8f4bd3 [pallas] Added more documentation for grid and BlockSpec.
The starting point was the text in pipelining.md, which I have
now replaced with a reference to the separate grid and BlockSpec
documentation.

The grids and BlockSpecs are also documented in the quickstart.md,
which I mostly left alone because it was good enough for a
simple example.

I have also attempted to add a few docstrings.
2024-06-29 14:43:48 +03:00
Jake VanderPlas
671db54f44 doc: remove references to submodules that no longer exist
2024-06-28 12:39:14 -07:00
George Necula
cbe524298c Ported threefry2x32 for GPU to the typed XLA FFI
This allows lowering of threefry2x32 for GPU even on a machine without GPUs.

For the next 3 weeks, we only use the new custom call implementation if
we are not in "export" mode, and if we use a new jaxlib.

PiperOrigin-RevId: 647657084
2024-06-28 06:24:44 -07:00
George Necula
47f1b3de2c [export] Add documentation for debugging and for ensuring compatibility.
The rendered documentation is at https://jax--21976.org.readthedocs.build/en/21976/export/export.html#developer-documentation (for the export developer documentation, including compatibility) and https://jax--21976.org.readthedocs.build/en/21976/export/shape_poly.html#debugging (for the shape polymorphism debugging documentation)

While testing the compatibility mechanism I discovered that it can be circumvented by caches.
To fix this, I added export_ignore_forward_compatibility to mlir.LoweringParameters.
2024-06-28 08:36:55 +03:00
Dan Foreman-Mackey
dda6430f7c Add register_custom_call_target to xla_client API docs.
This function is (for better or worse) user facing for custom call
users. I think it's worth having this in the API docs.
2024-06-27 14:40:36 -04:00