Peter Hawkins
dedd69f323
Add a bazel test that verifies that the jaxlib wheel builds.
2024-01-11 23:22:17 +00:00
jax authors
8a8cd6d01a
Merge pull request #19321 from jakevdp:diagonal
...
PiperOrigin-RevId: 597660818
2024-01-11 14:55:50 -08:00
Sharad Vikram
598b46aab5
[Pallas/TPU] Open source "Splash Attention" (Sparse Flash Attention), a general purpose attention kernel where you can specify an attention mask using NumPy.
...
PiperOrigin-RevId: 597658315
2024-01-11 14:47:39 -08:00
Sharad Vikram
5258f7cdcc
Fix all_gather_test
...
PiperOrigin-RevId: 597657438
2024-01-11 14:38:43 -08:00
Sharad Vikram
548bdd02a8
Add verbose kwarg to assertArraysEqual
...
PiperOrigin-RevId: 597650736
2024-01-11 14:11:53 -08:00
Sergei Lebedev
5b7a0d9c91
Pallas now uses MLIR Python builders to lower to Triton IR
...
This allows us to drop a dependency on the Triton Python package in the future,
and delegate ->ptx compilation to XLA.
PiperOrigin-RevId: 597640756
2024-01-11 13:33:26 -08:00
Jake VanderPlas
b08a010949
[array API] add jnp.linalg.diagonal
2024-01-11 12:52:15 -08:00
Peter Hawkins
35fc2ed8e0
Disable ASAN for several CUDA tests.
...
PiperOrigin-RevId: 597596726
2024-01-11 10:43:38 -08:00
jax authors
5c7ea22614
Merge pull request #19299 from jakevdp:array-api-simplify
...
PiperOrigin-RevId: 597580538
2024-01-11 09:51:02 -08:00
jax authors
35345e7c7f
Merge pull request #19303 from mattjj:pallas-block-matrix-alignment
...
PiperOrigin-RevId: 597580076
2024-01-11 09:42:21 -08:00
jax authors
59ea9f3fde
[triton] Use cuLaunchKernelEx instead of cuLaunchKernel
...
PiperOrigin-RevId: 597555083
2024-01-11 07:52:07 -08:00
Adam Paszke
8f771b4211
[Mosaic] Simplify the handling of dynamic indices in vector.load and store
...
This normalizes loads and stores with dynamic base indices into reference
slicing followed by statically indexed loads/stores. This should both simplify
the code (we only have to deal with dynamism in slicing) and improve performance
(we might offset the address once).
PiperOrigin-RevId: 597546106
2024-01-11 07:08:07 -08:00
Adam Paszke
ce00e10d9b
[Pallas][Mosaic] Add support for nontrivial semaphore memrefs
...
The previous patch simply changed the type we use to represent semaphores,
but didn't actually add support for any more operations. With this one,
semaphore memrefs can be allocated and (dynamically) indexed.
PiperOrigin-RevId: 597538913
2024-01-11 06:33:49 -08:00
Peter Hawkins
858fd52ac0
Fix jaxlib wheel build after removal of mosaic python files.
...
PiperOrigin-RevId: 597536465
2024-01-11 06:21:07 -08:00
Adam Paszke
57506b50c5
[Mosaic] Make sure to infer native tiling for results of TruncIOp that are fed into matmuls
...
This replicates the optimization we already apply while truncating floating point types.
Also, the heuristic used previously didn't include the tpu.matmul op, which could have
led to some performance degradation.
PiperOrigin-RevId: 597514672
2024-01-11 04:15:36 -08:00
Matthew Johnson
efe78c53ec
improve block matrix alignment in pallas docs
2024-01-10 22:01:16 -08:00
jax authors
e6783b3d1e
Update XLA dependency to use revision
...
fb2f766e6b
.
PiperOrigin-RevId: 597414168
2024-01-10 19:15:43 -08:00
Skye Wanderman-Milne
43918b7d87
[jax] Add pybind for PjRtExecutable::GetCostAnalysis
/Executable.cost_analysis
...
This makes cross-compiled `Compiled.cost_analysis` work.
PiperOrigin-RevId: 597411014
2024-01-10 18:51:55 -08:00
Hyeontaek Lim
f2e526dc78
Internal code cleanup for reducing private API access.
...
PiperOrigin-RevId: 597397571
2024-01-10 17:25:05 -08:00
Jake VanderPlas
c906f44ac1
array api: simplify some wrappers
2024-01-10 15:49:15 -08:00
jax authors
bbf2ab00a7
Merge pull request #19293 from jakevdp:jnp-pow
...
PiperOrigin-RevId: 597371608
2024-01-10 15:26:30 -08:00
Jake VanderPlas
1a39d8fdb2
[array API] implement jnp.pow; alias for jnp.power
2024-01-10 14:59:46 -08:00
jax authors
4e5430dca8
Merge pull request #19278 from jakevdp:jnp-bitwise
...
PiperOrigin-RevId: 597364394
2024-01-10 14:58:04 -08:00
Jake VanderPlas
4e55086dfb
array api: add jnp.bitwise_* aliases
2024-01-10 14:22:20 -08:00
Jevin Jiang
57f05592dd
[XLA:Mosaic] Support inputs and outputs in scf::ForOp and add tpu::AssumeLayoutOp to work around block argument as operand.
...
PiperOrigin-RevId: 597353171
2024-01-10 14:15:05 -08:00
jax authors
0c4b680271
Merge pull request #19274 from jakevdp:vecdot
...
PiperOrigin-RevId: 597347328
2024-01-10 13:53:32 -08:00
Sergei Lebedev
6174145386
Removed the Triton dependency from the BUILD file
...
PiperOrigin-RevId: 597336551
2024-01-10 13:17:47 -08:00
Jake VanderPlas
9890b23b0a
Add jnp.vecdot
2024-01-10 13:11:37 -08:00
jax authors
7ad53a6e67
Merge pull request #19276 from jakevdp:mypy-version
...
PiperOrigin-RevId: 597332904
2024-01-10 13:09:23 -08:00
Tomás Longeri
027c24e602
[Mosaic] Remove Python implementation of apply_vector_layout and infer_memref_layout.
...
PiperOrigin-RevId: 597332393
2024-01-10 13:00:21 -08:00
Eugene Zhulenev
ba4c2b1c7d
[pjrt:cpu] Add CpuTopology to TfrtCpuClient and enable persistent compilation cache for cpu backend
...
PiperOrigin-RevId: 597327136
2024-01-10 12:40:57 -08:00
Jake VanderPlas
6569b320b2
CI: bump mypy to version 1.8.0
2024-01-10 10:20:55 -08:00
jax authors
6f5acf0c0d
Merge pull request #19289 from jakevdp:update-array-api-tests
...
PiperOrigin-RevId: 597281920
2024-01-10 10:04:29 -08:00
Jake VanderPlas
635b9b6029
CI: update array API tests pin
2024-01-10 09:27:57 -08:00
jax authors
f31d1e7599
Bump NCCL version on JAX OSS
...
PiperOrigin-RevId: 597257455
2024-01-10 08:28:15 -08:00
jax authors
adf05d520a
Merge pull request #19282 from gnecula:poly_is_symb
...
PiperOrigin-RevId: 597170362
2024-01-10 01:04:21 -08:00
George Necula
df280a11b0
[shape_poly] Introduce is_symbolic_dim and deprecate is_poly_dim.
...
The old is_poly_dim seems to be used in a few places externally.
This was from the time when the symbolic dimensions were polynomials,
now we use the more generic term symbolic dimension or expression.
We introduce is_symbolic_dim and we export it through the jax.experimental.export.
We plan to make the entire shape_poly.py module private, and this is
a necessary step.
2024-01-10 10:10:30 +02:00
jax authors
88169cf9e5
Merge pull request #19275 from j2kun:main
...
PiperOrigin-RevId: 597148779
2024-01-09 23:06:02 -08:00
jax authors
cbb40c101d
Update XLA dependency to use revision
...
dcb94c46b2
.
PiperOrigin-RevId: 597109537
2024-01-09 19:34:41 -08:00
Nitin Srinivasan
b58772cdb4
Enable Bazel remote cache in macOS continuous builds
...
PiperOrigin-RevId: 597100776
2024-01-09 18:30:31 -08:00
Jeremy Kun
2e6e5da49b
docs/pallas: remove list from out_specs
2024-01-09 15:47:18 -08:00
jax authors
f74faf9c9c
Merge pull request #19279 from jakevdp:fix-jupytext
...
PiperOrigin-RevId: 597063608
2024-01-09 15:36:08 -08:00
Jake VanderPlas
93500a8477
lint: fix jupytext version
2024-01-09 15:28:17 -08:00
jax authors
77cb0f141c
Merge pull request #19147 from 8bitmp3:jax-docs-jaxprs
...
PiperOrigin-RevId: 597057971
2024-01-09 15:21:43 -08:00
jax authors
38f23b49af
Merge pull request #19277 from jakevdp:jupytext-version
...
PiperOrigin-RevId: 597057949
2024-01-09 15:13:40 -08:00
Jeremy Kun
4ecbed1322
docs/pallas: sync jupyter notebook
2024-01-09 15:00:42 -08:00
Jake VanderPlas
10eae3f93a
CI: update jupytext version
2024-01-09 14:34:21 -08:00
Jeremy Kun
87f914e02f
docs/pallas: fix mismatched pytree specs
...
Running the example as-is gives
```
ValueError: Pytree specs for `out_shape` and `out_specs` must match: PyTreeDef(*) vs. PyTreeDef((*,))
```
Giving a list argument to `out_shape` seems to fix the issue.
2024-01-09 13:56:28 -08:00
8bitmp3
66a845e0c6
Upgrade JAX internals 301 jaxpr language tutorial
2024-01-09 21:53:22 +00:00
jax authors
04f2c91399
Merge pull request #19092 from 8bitmp3:jax-docs-gradient-checkpointing
...
PiperOrigin-RevId: 597034325
2024-01-09 13:50:13 -08:00