18921 Commits

Author SHA1 Message Date
Peter Hawkins
dedd69f323 Add a bazel test that verifies that the jaxlib wheel builds. 2024-01-11 23:22:17 +00:00
jax authors
8a8cd6d01a Merge pull request #19321 from jakevdp:diagonal
PiperOrigin-RevId: 597660818
2024-01-11 14:55:50 -08:00
Sharad Vikram
598b46aab5 [Pallas/TPU] Open source "Splash Attention" (Sparse Flash Attention), a general purpose attention kernel where you can specify an attention mask using NumPy.
PiperOrigin-RevId: 597658315
2024-01-11 14:47:39 -08:00
Sharad Vikram
5258f7cdcc Fix all_gather_test
PiperOrigin-RevId: 597657438
2024-01-11 14:38:43 -08:00
Sharad Vikram
548bdd02a8 Add verbose kwarg to assertArraysEqual
PiperOrigin-RevId: 597650736
2024-01-11 14:11:53 -08:00
Sergei Lebedev
5b7a0d9c91 Pallas now uses MLIR Python builders to lower to Triton IR
This allows us to drop a dependency on the Triton Python package in the future,
and delegate ->ptx compilation to XLA.

PiperOrigin-RevId: 597640756
2024-01-11 13:33:26 -08:00
Jake VanderPlas
b08a010949 [array API] add jnp.linalg.diagonal 2024-01-11 12:52:15 -08:00
Peter Hawkins
35fc2ed8e0 Disable ASAN for several CUDA tests.
PiperOrigin-RevId: 597596726
2024-01-11 10:43:38 -08:00
jax authors
5c7ea22614 Merge pull request #19299 from jakevdp:array-api-simplify
PiperOrigin-RevId: 597580538
2024-01-11 09:51:02 -08:00
jax authors
35345e7c7f Merge pull request #19303 from mattjj:pallas-block-matrix-alignment
PiperOrigin-RevId: 597580076
2024-01-11 09:42:21 -08:00
jax authors
59ea9f3fde [triton] Use cuLaunchKernelEx instead of cuLaunchKernel
PiperOrigin-RevId: 597555083
2024-01-11 07:52:07 -08:00
Adam Paszke
8f771b4211 [Mosaic] Simplify the handling of dynamic indices in vector.load and store
This normalizes loads and stores with dynamic base indices into reference
slicing followed by statically indexed loads/stores. This should both simplify
the code (we only have to deal with dynamism in slicing) and improve performance
(we might offset the address once).

PiperOrigin-RevId: 597546106
2024-01-11 07:08:07 -08:00
Adam Paszke
ce00e10d9b [Pallas][Mosaic] Add support for nontrivial semaphore memrefs
The previous patch simply changed the type we use to represent semaphores,
but didn't actually add support for any more operations. With this one,
semaphore memrefs can be allocated and (dynamically) indexed.

PiperOrigin-RevId: 597538913
2024-01-11 06:33:49 -08:00
Peter Hawkins
858fd52ac0 Fix jaxlib wheel build after removal of mosaic python files.
PiperOrigin-RevId: 597536465
2024-01-11 06:21:07 -08:00
Adam Paszke
57506b50c5 [Mosaic] Make sure to infer native tiling for results of TruncIOp that are fed into matmuls
This replicates the optimization we already apply while truncating floating point types.
Also, the heuristic used previously didn't include the tpu.matmul op, which could have
led to some performance degradation.

PiperOrigin-RevId: 597514672
2024-01-11 04:15:36 -08:00
Matthew Johnson
efe78c53ec improve block matrix alignment in pallas docs 2024-01-10 22:01:16 -08:00
jax authors
e6783b3d1e Update XLA dependency to use revision
fb2f766e6b.

PiperOrigin-RevId: 597414168
2024-01-10 19:15:43 -08:00
Skye Wanderman-Milne
43918b7d87 [jax] Add pybind for PjRtExecutable::GetCostAnalysis/Executable.cost_analysis
This makes cross-compiled `Compiled.cost_analysis` work.

PiperOrigin-RevId: 597411014
2024-01-10 18:51:55 -08:00
Hyeontaek Lim
f2e526dc78 Internal code cleanup for reducing private API access.
PiperOrigin-RevId: 597397571
2024-01-10 17:25:05 -08:00
Jake VanderPlas
c906f44ac1 array api: simplify some wrappers 2024-01-10 15:49:15 -08:00
jax authors
bbf2ab00a7 Merge pull request #19293 from jakevdp:jnp-pow
PiperOrigin-RevId: 597371608
2024-01-10 15:26:30 -08:00
Jake VanderPlas
1a39d8fdb2 [array API] implement jnp.pow; alias for jnp.power 2024-01-10 14:59:46 -08:00
jax authors
4e5430dca8 Merge pull request #19278 from jakevdp:jnp-bitwise
PiperOrigin-RevId: 597364394
2024-01-10 14:58:04 -08:00
Jake VanderPlas
4e55086dfb array api: add jnp.bitwise_* aliases 2024-01-10 14:22:20 -08:00
Jevin Jiang
57f05592dd [XLA:Mosaic] Support inputs and outputs in scf::ForOp and add tpu::AssumeLayoutOp to work around block argument as operand.
PiperOrigin-RevId: 597353171
2024-01-10 14:15:05 -08:00
jax authors
0c4b680271 Merge pull request #19274 from jakevdp:vecdot
PiperOrigin-RevId: 597347328
2024-01-10 13:53:32 -08:00
Sergei Lebedev
6174145386 Removed the Triton dependency from the BUILD file
PiperOrigin-RevId: 597336551
2024-01-10 13:17:47 -08:00
Jake VanderPlas
9890b23b0a Add jnp.vecdot 2024-01-10 13:11:37 -08:00
jax authors
7ad53a6e67 Merge pull request #19276 from jakevdp:mypy-version
PiperOrigin-RevId: 597332904
2024-01-10 13:09:23 -08:00
Tomás Longeri
027c24e602 [Mosaic] Remove Python implementation of apply_vector_layout and infer_memref_layout.
PiperOrigin-RevId: 597332393
2024-01-10 13:00:21 -08:00
Eugene Zhulenev
ba4c2b1c7d [pjrt:cpu] Add CpuTopology to TfrtCpuClient and enable persistent compilation cache for cpu backend
PiperOrigin-RevId: 597327136
2024-01-10 12:40:57 -08:00
Jake VanderPlas
6569b320b2 CI: bump mypy to version 1.8.0 2024-01-10 10:20:55 -08:00
jax authors
6f5acf0c0d Merge pull request #19289 from jakevdp:update-array-api-tests
PiperOrigin-RevId: 597281920
2024-01-10 10:04:29 -08:00
Jake VanderPlas
635b9b6029 CI: update array API tests pin 2024-01-10 09:27:57 -08:00
jax authors
f31d1e7599 Bump NCCL version on JAX OSS
PiperOrigin-RevId: 597257455
2024-01-10 08:28:15 -08:00
jax authors
adf05d520a Merge pull request #19282 from gnecula:poly_is_symb
PiperOrigin-RevId: 597170362
2024-01-10 01:04:21 -08:00
George Necula
df280a11b0 [shape_poly] Introduce is_symbolic_dim and deprecate is_poly_dim.
The old is_poly_dim seems to be used in a few places externally.
This was from the time when the symbolic dimensions were polynomials,
now we use the more generic term symbolic dimension or expression.

We introduce is_symbolic_dim and we export it through the jax.experimental.export.
We plan to make the entire shape_poly.py module private, and this is
a necessary step.
2024-01-10 10:10:30 +02:00
jax authors
88169cf9e5 Merge pull request #19275 from j2kun:main
PiperOrigin-RevId: 597148779
2024-01-09 23:06:02 -08:00
jax authors
cbb40c101d Update XLA dependency to use revision
dcb94c46b2.

PiperOrigin-RevId: 597109537
2024-01-09 19:34:41 -08:00
Nitin Srinivasan
b58772cdb4 Enable Bazel remote cache in macOS continuous builds
PiperOrigin-RevId: 597100776
2024-01-09 18:30:31 -08:00
Jeremy Kun
2e6e5da49b docs/pallas: remove list from out_specs 2024-01-09 15:47:18 -08:00
jax authors
f74faf9c9c Merge pull request #19279 from jakevdp:fix-jupytext
PiperOrigin-RevId: 597063608
2024-01-09 15:36:08 -08:00
Jake VanderPlas
93500a8477 lint: fix jupytext version 2024-01-09 15:28:17 -08:00
jax authors
77cb0f141c Merge pull request #19147 from 8bitmp3:jax-docs-jaxprs
PiperOrigin-RevId: 597057971
2024-01-09 15:21:43 -08:00
jax authors
38f23b49af Merge pull request #19277 from jakevdp:jupytext-version
PiperOrigin-RevId: 597057949
2024-01-09 15:13:40 -08:00
Jeremy Kun
4ecbed1322 docs/pallas: sync jupyter notebook 2024-01-09 15:00:42 -08:00
Jake VanderPlas
10eae3f93a CI: update jupytext version 2024-01-09 14:34:21 -08:00
Jeremy Kun
87f914e02f docs/pallas: fix mismatched pytree specs
Running the example as-is gives

```
ValueError: Pytree specs for `out_shape` and `out_specs` must match: PyTreeDef(*) vs. PyTreeDef((*,))
```

Giving a list argument to `out_shape` seems to fix the issue.
2024-01-09 13:56:28 -08:00
8bitmp3
66a845e0c6 Upgrade JAX internals 301 jaxpr language tutorial 2024-01-09 21:53:22 +00:00
jax authors
04f2c91399 Merge pull request #19092 from 8bitmp3:jax-docs-gradient-checkpointing
PiperOrigin-RevId: 597034325
2024-01-09 13:50:13 -08:00