19762 Commits

Author SHA1 Message Date
Neil Girdhar
1e580457ba Repair various type errors 2024-03-13 15:13:56 -04:00
Sergei Lebedev
a8e2ee9b65 Log the exception if the callback passed to jax.*_callback raises
PiperOrigin-RevId: 615407343
2024-03-13 07:23:29 -07:00
Peter Hawkins
cf856ad4a9 Reverts 8e2a8b7b95e838947dcf581d146909d5c4128742
PiperOrigin-RevId: 615401711
2024-03-13 07:01:49 -07:00
Peter Hawkins
642f20de1c [JAX] Convert stablehlo to MLIR bytecode, not an MLIR string.
Bytecode is considerably more compact.

PiperOrigin-RevId: 615386276
2024-03-13 06:02:18 -07:00
Sergei Lebedev
f0c5051004 Added a Pallas GPU test for jnp.invert
PiperOrigin-RevId: 615369570
2024-03-13 04:48:38 -07:00
Sergei Lebedev
926e673f61 Removed unused _pack_indices
PiperOrigin-RevId: 615340548
2024-03-13 02:45:25 -07:00
Sergei Lebedev
168f30a4c8 Use int.bit_length() in next_power_of_2
See https://docs.python.org/3/library/stdtypes.html#int.bit_length.

PiperOrigin-RevId: 615340317
2024-03-13 02:45:08 -07:00
Sergei Lebedev
a7964445e6 Fixed the signature of the fallback get_compute_capability
PiperOrigin-RevId: 615338312
2024-03-13 02:34:41 -07:00
jax authors
187b7aa8e6 Update XLA dependency to use revision
d7b3099f62.

PiperOrigin-RevId: 615284678
2024-03-12 22:08:25 -07:00
jax authors
60bf38bde9 Merge pull request #20128 from shuhand0:dev/shuhan/ci2
PiperOrigin-RevId: 615231412
2024-03-12 17:45:19 -07:00
Shuhan Ding
5a93e15bd7
add to tests/BUILD 2024-03-12 17:17:20 -07:00
jax authors
f6b7207e43 Merge pull request #20215 from jakevdp:fix-eager-test
PiperOrigin-RevId: 615216265
2024-03-12 16:56:10 -07:00
Jake VanderPlas
2ba9b45277 [key-reuse] fix flaky test 2024-03-12 16:49:16 -07:00
Yash Katariya
64622d6a64 Don't index into numpy array if the sharding is fully replicated.
PiperOrigin-RevId: 615202739
2024-03-12 16:04:52 -07:00
Sergei Lebedev
75dbd30a93 Added tests for more Pallas GPU lowering rules
As expected, this uncovered lots of small typos all over the place.

PiperOrigin-RevId: 615191542
2024-03-12 15:28:01 -07:00
jax authors
2827564e86 Merge pull request #20204 from mattjj:shmap-tutorial-typos-2
PiperOrigin-RevId: 615147238
2024-03-12 13:12:54 -07:00
Matthew Johnson
bf495bf5cb fix typos in shard_map tutorial 2024-03-12 11:58:26 -07:00
jax authors
f8af4ef816 Merge pull request #20201 from mattjj:attrs-notimplemented-error
PiperOrigin-RevId: 615123400
2024-03-12 11:56:05 -07:00
jax authors
e809683013 Merge pull request #20199 from jakevdp:solve-tests
PiperOrigin-RevId: 615115601
2024-03-12 11:31:54 -07:00
Matthew Johnson
c1dd67b1fe [attrs] 'too many values to unpack' is code for notimplementederror 2024-03-12 11:12:57 -07:00
Jake VanderPlas
46fab93371 Fix jnp.solve test for numpy 2.0 2024-03-12 10:15:38 -07:00
jax authors
d448f6670b Merge pull request #20194 from Micky774:array_api_error_type
PiperOrigin-RevId: 615084657
2024-03-12 10:10:56 -07:00
jax authors
11efc9b93b Merge pull request #20197 from ayaka14732:fix-typo-1
PiperOrigin-RevId: 615061021
2024-03-12 09:08:19 -07:00
jax authors
75666ec945 Merge pull request #20185 from rajasekharporeddy:doc_typos
PiperOrigin-RevId: 615057846
2024-03-12 09:07:57 -07:00
jax authors
e3d9d9d82b Merge pull request #20189 from rajasekharporeddy:rajasekharporeddy-patch-1
PiperOrigin-RevId: 615057844
2024-03-12 08:57:42 -07:00
Ayaka
cf62573d42 Fix typo in multi_process.md 2024-03-12 14:34:10 -01:00
Sergei Lebedev
ae1cfc21c3 Added a lowering rule for lax.sign_p and improved test coverage for binary ops
Closes #17317

PiperOrigin-RevId: 615038353
2024-03-12 07:48:43 -07:00
Meekail Zain
9924a0cb65 Update 2024-03-12 12:56:22 +00:00
Adam Paszke
d0eae05741 Add a test for grid overflows in dynamic grid lowering.
PiperOrigin-RevId: 614980113
2024-03-12 03:38:24 -07:00
jax authors
9dbf758a3d Merge pull request #20182 from superbobry:bye-xmap
PiperOrigin-RevId: 614965089
2024-03-12 02:32:21 -07:00
rajasekharporeddy
e94299c946
Fix Typos in CHANGELOG.md
This PR fixes the typos in Change log documentation
2024-03-12 13:57:07 +05:30
jax authors
98ad6ef057 Add debug logging for autotune profile sharing.
PiperOrigin-RevId: 614915339
2024-03-11 22:46:56 -07:00
jax authors
2804cc78bf Update XLA dependency to use revision
3b7ec75e03.

PiperOrigin-RevId: 614913376
2024-03-11 22:35:36 -07:00
rajasekharporeddy
e32bac4b3a Fixed Typos in JEP doc files 2024-03-12 09:58:13 +05:30
Jevin Jiang
30208fa9cc [XLA:Mosaic] Support strided load/store memref with arbitrary shape as long as last dim size is 128 and dtype is 32bit.
PiperOrigin-RevId: 614862128
2024-03-11 18:22:11 -07:00
jax authors
63538771b5 Merge pull request #20183 from jakevdp:key-reuse-concatenate
PiperOrigin-RevId: 614851106
2024-03-11 17:37:59 -07:00
jax authors
8ae93d5a9a Merge pull request #20181 from jakevdp:reuse-signature-repr
PiperOrigin-RevId: 614821824
2024-03-11 15:53:19 -07:00
jax authors
fe44afc0fc Merge pull request #20161 from rajasekharporeddy:doc_typos
PiperOrigin-RevId: 614813533
2024-03-11 15:27:06 -07:00
Jake VanderPlas
6cf740ceb1 [key reuse] improve repr for signatures 2024-03-11 15:17:08 -07:00
Jake VanderPlas
3eff032aba [key reuse] define rule for lax.concatenate 2024-03-11 15:06:59 -07:00
Sergei Lebedev
0bd7070000 Fixed lowering of binary ops for signed dtypes
All integers in Trition are signless, so we need to manually forward the
signedness of the abstract values.

I wonder if we should avoid relying on MLIR types altogether and change _cast
and similar APIs to accept JAX dtypes instead?

PiperOrigin-RevId: 614803683
2024-03-11 14:54:54 -07:00
Sergei Lebedev
88853c1fe3 Add a deprecation warning to the xmap function and the corresponding tutorial 2024-03-11 21:31:32 +00:00
Jian Li
b6e985ffe7 Add int4 test to ArrayImpl.
PiperOrigin-RevId: 614778550
2024-03-11 13:40:11 -07:00
Sergei Lebedev
778933dfda Removed inspect.signature() call from jaxlib.triton.dialect.ScanOp
PiperOrigin-RevId: 614772594
2024-03-11 13:30:41 -07:00
jax authors
93e5bbe039 A fusion flag for each operand is set to false by default. A custom call writer is expected to turn them on if he expects those fusions to be profitable. The operand may not fuse despite the flag being turned to true because of other constrains such as estimated memory required after fusion.
PiperOrigin-RevId: 614772584
2024-03-11 13:21:06 -07:00
rajasekharporeddy
61c64c10f8 Fixed Several Typos
Fixed Typos in JEP doc files

Revert "Fixed Typos in JEP doc files"

This reverts commit c2a16950e0fc1b32971168501d183991e2394b5d.

revert two changes

reverted one change in advanced-autodiff

revert one change in parallelism

sync notebooks
2024-03-12 00:37:46 +05:30
Goran Flegar
53364b438c Integrate Triton up to [bfb8e413](bfb8e413b0)
PiperOrigin-RevId: 614740360
2024-03-11 11:43:46 -07:00
Shuhan Ding
4c4bcde723
pass set of lax_numpy_test 2024-03-11 11:18:58 -07:00
Adam Paszke
71ec6e33ca Make pl.num_programs lowering take the vmapped axes into account
Otherwise the size of the wrong axis is returned.

PiperOrigin-RevId: 614677218
2024-03-11 08:41:20 -07:00
Peter Hawkins
de455e7003 Fix small bug in random_test.
unsafe_buffer_pointer() and on_device_size_in_bytes() are methods, not properties, so presumably the test intended to call them rather than test equality of the bound methods.

PiperOrigin-RevId: 614651090
2024-03-11 07:04:58 -07:00