1804 Commits

Author SHA1 Message Date
George Necula
6c5583d6aa [pallas] Document and test valid block shapes.
I have only added tests and documentation, will improve error
reporting separately.

For TPU we get a mix of errors from either the Pallas lowering or
from Mosaic. I plan to add lowering exception for all unsupported
cases, so that we have a better Python stack trace available.

For GPU, we get a RET_CHECK instead of a Python exception,
so I had to add skipTest. Will fix the error message separately.

In order to be able to put the test in pallas_test::PallasCallTest, I
moved the skipTest for TPU from the setUp to the individual tests that
need this.

PiperOrigin-RevId: 653195289
2024-07-17 05:17:20 -07:00
jax authors
85c30d2a86 Merge pull request #20021 from sharadmv:pallas-matmul-docs
PiperOrigin-RevId: 653070524
2024-07-16 20:16:04 -07:00
Sharad Vikram
10e09af7a0 Address changes 2024-07-16 19:25:32 -07:00
Sharad Vikram
ff62d5e229 Address changes 2024-07-16 19:24:56 -07:00
Justin Fu
6ba889c01c [Pallas] Add support for checkify in TPU execution mode.
PiperOrigin-RevId: 653045818
2024-07-16 18:13:02 -07:00
Sharad Vikram
39ec5dacb4 [Pallas TPU] Add matrix multiplication tutorial 2024-07-16 18:12:19 -07:00
jax authors
5ddec63a47 Merge pull request #22441 from gnecula:test_clean_hypothesis
PiperOrigin-RevId: 652919414
2024-07-16 11:32:46 -07:00
kaixih
0d387e0839 Update jax doc sdpa 2024-07-15 17:30:54 +00:00
jax authors
2b29a94255 Merge pull request #22375 from jakevdp:mypy-docs
PiperOrigin-RevId: 652511749
2024-07-15 09:52:07 -07:00
George Necula
d3454f374e Add some hypothesis testing utilities and developer documentation.
Add a helper function for setting up hypothesis testing,
with support for selecting an interactive hypothesis profile
that speeds up interactive development.
2024-07-15 17:05:32 +02:00
George Necula
be8e83adc1 [docs] Fix docs building error
The checkify APIs were mentioned in the jax.experimental.rst and also
in jax.experimental.checkify.rst.
2024-07-15 15:42:33 +01:00
jax authors
f60643801d Merge pull request #22370 from gnecula:pallas_unblocked
PiperOrigin-RevId: 651770174
2024-07-12 07:41:38 -07:00
George Necula
7c059d4630 [pallas] Document the indexing_mode=Unblocked()
In the process discovered that the padding in the interpreter
mode was with 0s. I changed it to NaN/minint to match the
padding for the blocked mode.
2024-07-12 12:39:10 +03:00
George Necula
9cd94019b4 [pallas] Added a CHANGELOG for Pallas
The CHANGELOG is populated with the changes since June 10th, when
JAX 0.4.29 was released.
2024-07-12 00:05:31 +03:00
George Necula
ea548e7c86 [pallas] Add more documentation and tests for BlockSpec.
This PR deals with the default values for the parameters
of the `BlockSpec` constructor, and the mapped block dimensions.

Fix a bug where previously a missing block_shape while the
index_map was present was resulting in a crash.
2024-07-10 19:16:53 +03:00
Jake VanderPlas
f3b7aea283 DOC: improve mypy/pre-commit instructions 2024-07-10 09:06:03 -07:00
Tom Ward
ebfbd8ac0c Fix cuda custom call example to build with updated XLA FFI API.
PiperOrigin-RevId: 650977379
2024-07-10 05:29:58 -07:00
Vadym Matsishevskyi
fb3607c1d5 Use inclusion list configuration for local wheels.
Also some documentation improvements/clarifications.

This allows it to not remove unused local wheels from the dist directory to avoid conflicts.

PiperOrigin-RevId: 650697758
2024-07-09 11:25:31 -07:00
bion howard
1ace88bfba
Update quickstart.md
fix minor grammar typo
2024-07-09 12:48:51 -04:00
George Necula
f02d32c680 [pallas] Fix the interpreter for block_shape not dividing the overall shape
Before this change, the interpreter was failing with an MLIR
verification error because the body of the while loop returned
a padded output array.

This change allows us to expand the documentation of block specs
with the case for when block_shape does not divide the overall shape.
2024-07-09 16:10:22 +03:00
jax authors
0d4e0ecf65 Merge pull request #22271 from ayaka14732:lru-cache-6
PiperOrigin-RevId: 650203793
2024-07-08 04:39:58 -07:00
George Necula
3df602882c [shape_poly] Small improvement in the documentation
Added an example for equality constraints.
2024-07-06 08:10:55 +03:00
Ayaka
db32021182 Update persistent compilation cache doc 2024-07-05 19:43:04 +08:00
jax authors
1e141577e3 Merge pull request #21819 from keshavb96:compilation_cache_doc
PiperOrigin-RevId: 649350829
2024-07-04 02:59:53 -07:00
jax authors
db13e6fc0e Merge pull request #22119 from dfm:cond-linear
PiperOrigin-RevId: 648535400
2024-07-01 17:36:59 -07:00
jax authors
b669ab7bb1 Merge pull request #21925 from dfm:ffi-call
PiperOrigin-RevId: 648532673
2024-07-01 17:24:10 -07:00
Sergei Lebedev
a2a5068e5e Changed `pl.BlockSpec to accept block_shape before index_map`
So, instead of

    pl.BlockSpec(lambda i, j: ..., (42, 24))

``pl.BlockSpec`` now expects

    pl.BlockSpec((42, 24), lambda i, j: ...)

I will update Pallas tests in a follow up.

PiperOrigin-RevId: 648486321
2024-07-01 14:26:08 -07:00
Dan Foreman-Mackey
6becf716f3 Remove linear parameter from lax.cond_p.
As far as I can tell, it seems like the `linear` parameter in the
`lax.cond_p` primitive only exists for historical reasons. It could be
used for type checking in `_cond_transpose`, but that was removed
because of #14026. With this in mind, we could stop tracking this
parameter as implemented in this PR, unless we expect that we'd want to
re-introduce the type checking in the future.
2024-07-01 10:25:42 -04:00
Dan Foreman-Mackey
e9b087d3a8 Add ffi_call function with a similar signature to pure_callback.
This could be useful for supporting the most common use cases for FFI custom
calls. It has several benefits over using the `Primitive` based approach, but
the biggest one (in my opinion) is that it doesn't require interacting with
`mlir` at all. It does have the limitation that transforms would need to be
registered using interfaces like `custom_vjp`, but many users of custom calls
already do that.

~~The easiest to-do item (I think) is to implement batching using a
`vectorized` parameter like `pure_callback`, but we could also think about more
sophisticated vmapping interfaces in the future.~~ Done.

The more difficult to-do is to think about how to support sharding, and we
might actually want to expose an interface similar to the one from
`custom_partitioning`. I have less experience with this part so I'll have to
think some more about it, and feedback would be appreciated!
2024-07-01 09:40:31 -04:00
Sergei Lebedev
e80632e6fd Revived the workaround for not-expanding type aliases
The version here only works for modules with
``from __future__ import annotations``, but we can safely add that import
to all modules now, since the minimal Python version JAX supports is 3.10.

The worakround was previously removed in #3485.
2024-07-01 14:31:53 +01:00
jax authors
5fac179f2f Merge pull request #22134 from gnecula:pallas_doc
PiperOrigin-RevId: 648147118
2024-06-30 09:15:16 -07:00
George Necula
bfdf8f4bd3 [pallas] Added more documentation for grid and BlockSpec.
The starting point was the text in pipelining.md, where I
replaced it now with a reference to the separate grid and BlockSpec
documentation.

The grids and BlockSpecs are also documented in the quickstart.md,
which I mostly left alone because it was good enough for a
simple example.

I have also attempted to add a few docstrings.
2024-06-29 14:43:48 +03:00
Jake VanderPlas
671db54f44 doc: remove references to submodules that no longer exist 2024-06-28 12:39:14 -07:00
George Necula
cbe524298c Ported threefry2x32 for GPU to the typed XLA FFI
This allows lowering of threefry2x32 for GPU even on a machine without GPUs.

For the next 3 weeks, we only use the new custom call implementation if
we are not in "export" mode, and if we use a new jaxlib.

PiperOrigin-RevId: 647657084
2024-06-28 06:24:44 -07:00
George Necula
47f1b3de2c [export] Add documentation for debugging and for ensuring compatibility.
The rendered documentation is at https://jax--21976.org.readthedocs.build/en/21976/export/export.html#developer-documentation (for the export developer documentation, including compatibility) and https://jax--21976.org.readthedocs.build/en/21976/export/shape_poly.html#debugging (for the shape polymorphism debugging documentation)

While testing the compatibility mechanism I discovered that it can be circumvented by caches.
To fix this, I added export_ignore_forward_compatibility to mlir.LoweringParameters.
2024-06-28 08:36:55 +03:00
Dan Foreman-Mackey
dda6430f7c Add register_custom_call_target to xla_client API docs.
This function is (for better or worse) user facing for custom call
users. I think it's worth having this in the API docs.
2024-06-27 14:40:36 -04:00
jax authors
00528b9858 libdevice.10.bc is removed from JAX wheels bundle.
The recommended source of JAX wheels is `pip`, and NVIDIA dependencies are installed automatically when JAX is installed via `pip install`. `libdevice` gets installed from `nvidia-cuda-nvcc-cu12` package.

PiperOrigin-RevId: 647328834
2024-06-27 08:35:59 -07:00
Peng Wang
99d90b23ee Fixes an error in jax.export.Exported's docstring.
PiperOrigin-RevId: 647115474
2024-06-26 16:27:07 -07:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Keshav
cf0b8fd93e minor edits 2024-06-26 12:52:01 -07:00
Peter Hawkins
945fde41e4 Update minimum Python version to 3.10. 2024-06-26 13:47:14 -04:00
jax authors
b8aa4c52ee Merge pull request #22091 from abhinavgoel95:patch-2
PiperOrigin-RevId: 646924044
2024-06-26 06:36:25 -07:00
Abhinav Goel
7b9636f63c
Update flags gpu_performance_tips.md 2024-06-25 08:53:43 -07:00
George Necula
8528f5127d [pallas] Break long lines in the Pallas docs
No content changes.
2024-06-25 13:30:17 +03:00
Dan Foreman-Mackey
df50d05aae Fix BUILD for CUDA custom call example in docs
With this PR, the example in `docs/cuda_custom_call` can now build and be
properly detected by Bazel.

PiperOrigin-RevId: 646149391
2024-06-24 10:50:48 -07:00
Neil Girdhar
56fdb42e9d Copy nn.{softmax,log_softmax} to scipy.special 2024-06-22 09:32:14 -04:00
jax authors
fc1e1d4a65 Add freshness metablock to JAX OSS docs.
PiperOrigin-RevId: 645508135
2024-06-21 14:50:49 -07:00
jax authors
e1b9b8e50e Merge pull request #21960 from hawkinsp:docs
PiperOrigin-RevId: 644426278
2024-06-18 10:08:45 -07:00
Lukas Geiger
8df518f7f8
Fix typo in jax.export migration guide 2024-06-18 16:27:18 +01:00
Peter Hawkins
3735a485b0 Remove [cpu] extra from installation instructions.
`pip install jax` should do the right thing by default now.
2024-06-18 11:12:23 -04:00