22787 Commits

Author SHA1 Message Date
Justin Fu
2d74c6aa05 Add TritonCompilerParams for specifying compiler arguments instead of a dict.
PiperOrigin-RevId: 671081069
2024-09-04 13:32:25 -07:00
Sergei Lebedev
a8a55e0f2e Added pl.CompilerParams subclass for Mosaic GPU
PiperOrigin-RevId: 671066741
2024-09-04 12:48:34 -07:00
Vladimir Belitskiy
3672b633c3 Fix a deprecation warning for NumPy array conversion.
To address  https://github.com/google/jax/actions/runs/10654663500/job/29531268089#step:6:656

```
E   DeprecationWarning: NumPy will stop allowing conversion of out-of-bound Python integers to integer arrays.  The conversion of 536870912 to int16 will fail in the future.
E   For the old behavior, usually:
E       np.array(value).astype(dtype)
E   will give the desired result (the cast overflows).
```

PiperOrigin-RevId: 671064730
2024-09-04 12:41:28 -07:00
Yash Katariya
bf66e816dd Split physical axes by default when device kind is TPU v5 lite to allow for mesh shapes (2, 2) when there are 8 v5e devices on a 4x2 topology.
PiperOrigin-RevId: 671047455
2024-09-04 11:49:17 -07:00
jax authors
8278a72b88 Merge pull request #23430 from froystig:optimizers
PiperOrigin-RevId: 671030783
2024-09-04 11:03:42 -07:00
jax authors
6929211a20 Merge pull request #23434 from rajasekharporeddy:testbranch2
PiperOrigin-RevId: 671027735
2024-09-04 10:56:48 -07:00
jax authors
2d023db8a2 Merge pull request #23433 from jakevdp:rng-impl-repr
PiperOrigin-RevId: 671027668
2024-09-04 10:56:31 -07:00
jax authors
72e74d1ef6 Merge pull request #23364 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 671027520
2024-09-04 10:55:24 -07:00
Jake VanderPlas
d6394c0795 random.key_impl: improve repr of output 2024-09-04 10:10:31 -07:00
jax authors
f90558d317 Merge pull request #23248 from ROCm:rocm-ubuntu-images
PiperOrigin-RevId: 670983320
2024-09-04 08:39:49 -07:00
rajasekharporeddy
10893033b9 Remove unused docstring addition: _PRECISION_DOC 2024-09-04 19:57:51 +05:30
rajasekharporeddy
cb45fb426a Better docs for jax.numpy: log and log1p 2024-09-04 19:22:58 +05:30
jax authors
22be4eafca Merge pull request #23398 from damianoamatruda:fix-pytype-array
PiperOrigin-RevId: 670937582
2024-09-04 05:45:16 -07:00
Roy Frostig
8310a6ab1b update example optimizers library docstring
* JAXopt is being merged into Optax, so point only to Optax
* Update Optax's github repository URL
2024-09-03 23:40:47 -07:00
Jevin Jiang
c1d3c2db9f [Mosaic TPU] Fix mosaic alignment check in concatenate rule.
PiperOrigin-RevId: 670837792
2024-09-03 22:57:27 -07:00
jax authors
ebc6c18152 Merge pull request #23417 from pkgoogle:better_true_divide_doc
PiperOrigin-RevId: 670754635
2024-09-03 17:02:35 -07:00
Yash Katariya
e1b497078e Rename jtu.create_global_mesh to jtu.create_mesh and use jax.make_mesh inside jtu.create_mesh to get maximum test coverage of the new API.
PiperOrigin-RevId: 670744047
2024-09-03 16:23:07 -07:00
jax authors
5c0ee1a3e9 Merge pull request #23422 from jakevdp:make-mesh-docs
PiperOrigin-RevId: 670742011
2024-09-03 16:16:04 -07:00
Jake VanderPlas
7569dd5438 Update sharded-computation doc to use make_mesh() 2024-09-03 16:04:23 -07:00
Piseth Ky
78212ae39e better true_divide and divide docs
doc wording update
2024-09-03 16:03:55 -07:00
Sergei Lebedev
1289640f09 Deprecated calling `jax.dlpack.from_dlpack` with a DLPack tensor
PiperOrigin-RevId: 670723176
2024-09-03 15:16:02 -07:00
jax authors
d0d7493aae Update XLA dependency to use revision
950df46440.

PiperOrigin-RevId: 670719765
2024-09-03 15:04:33 -07:00
Yash Katariya
252caebce3 Create jax.make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str], devices: Sequence[jax.Device] | None = None) API to make it easier to create a mesh and reduce a ton of boilerplate.
`jax.make_mesh` is the stable API endpoint of `mesh_utils` but without all the extra options. If you want those, you can still use the experimental endpoint in `mesh_utils`.

PiperOrigin-RevId: 670707995
2024-09-03 14:32:03 -07:00
Sergei Lebedev
9b31b73d9d Added basic pipelining to Pallas on Mosaic GPU
The implementation only allows at most one sequential axis at the moment.

PiperOrigin-RevId: 670687671
2024-09-03 13:37:46 -07:00
jax authors
cda2408e14 Merge pull request #23394 from gspschmid:gschmid/ffi-support-token
PiperOrigin-RevId: 670652970
2024-09-03 12:06:07 -07:00
jax authors
40bdbcdbc2 Merge pull request #23407 from hawkinsp:hardware
PiperOrigin-RevId: 670645993
2024-09-03 11:49:42 -07:00
Georg Stefan Schmid
24bb8ae443 [ffi] Add support for token inputs and outputs 2024-09-03 18:28:34 +00:00
Sergei Lebedev
9030aec097 Added a new Pallas Triton primitive -- `plgpu.debug_barrier`
Closes #23400.

PiperOrigin-RevId: 670636723
2024-09-03 11:27:12 -07:00
Peter Hawkins
f92d4e3e3d Add TPU v6e to the list of known TPUs.
JAX will warn if it sees a device ID on this list but the runtime doesn't find one.
2024-09-03 14:20:35 -04:00
jax authors
ff702cb249 Merge pull request #23396 from hawkinsp:ci
PiperOrigin-RevId: 670607474
2024-09-03 10:16:49 -07:00
jax authors
cd55d0c91b Merge pull request #23397 from jakevdp:dep-round
PiperOrigin-RevId: 670604941
2024-09-03 10:10:43 -07:00
Damiano Amatruda
87350b7128
Fix pytype errors and args for jax.Array methods 2024-09-03 17:06:45 +00:00
Jake VanderPlas
fd897745d3 Partial rollback of https://github.com/google/jax/pull/23353 as discussed in https://github.com/google/jax/pull/23353#issuecomment-2326604708
Reverts eed273c106af699efefc726eea1ff2b0f548f669

PiperOrigin-RevId: 670596159
2024-09-03 09:49:22 -07:00
jax authors
2a42bbe48a Merge pull request #22843 from jeertmans:scipy-special-fresnel
PiperOrigin-RevId: 670543225
2024-09-03 07:07:19 -07:00
Jake VanderPlas
f2ffe7f8f2 Deprecate jax.numpy.round_
NumPy removed np.round in version 2.0; jax.numpy.round is drop-in
replacement.
2024-09-03 06:52:07 -07:00
Peter Hawkins
bc415f9153 Relax test tolerances to fix CI failures on Mac ARM. 2024-09-03 09:45:28 -04:00
Sergei Lebedev
ccabd21084 Fixed rules where `sliding_window_length` was not forwarded
This is follow up to #23284.

PiperOrigin-RevId: 670531634
2024-09-03 06:24:01 -07:00
Chris Jones
7b161fb76c [jax:pallas] Use 64-bit indexing when necessary when lowering to Triton.
PiperOrigin-RevId: 670530776
2024-09-03 06:20:39 -07:00
Adam Paszke
4c3111bf26 [Mosaic GPU] Unbreak tests
I mistakenly checked for `amount + 1` instead of `amount * 2`. It initially
seemed right because both expressions evalute to 2 for 1 :)

PiperOrigin-RevId: 670527107
2024-09-03 06:07:54 -07:00
jax authors
eed273c106 Merge pull request #23353 from jakevdp:lax-deps
PiperOrigin-RevId: 670523237
2024-09-03 05:59:26 -07:00
Chris Jones
a1fd582ad6 [jax:pallas] Simplify pointer offset calculation in Triton lowering.
PiperOrigin-RevId: 670499398
2024-09-03 04:25:06 -07:00
Jérome Eertmans
f9cb95ca08
feat(lib): add real-valued implementation of jax.scipy.special.fresnel
Add implementation, documentation, and tests, for both single-precision and double-precision floating-point arithmetic.
2024-09-03 09:50:19 +02:00
Fabian Pedregosa
530ed026b8 FIX typo on jax.numpy.where docstring
this was preventing the link to be correctly rendered in the webpage https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html

PiperOrigin-RevId: 670385290
2024-09-02 20:50:00 -07:00
jax authors
225a028f1d Merge pull request #23371 from rajasekharporeddy:testbranch3
PiperOrigin-RevId: 670380554
2024-09-02 20:27:24 -07:00
jax authors
281bfcdc62 Merge pull request #23387 from mattjj:shmap-leak-checker
PiperOrigin-RevId: 670380518
2024-09-02 20:27:08 -07:00
jax authors
826e661347 Merge pull request #23365 from rajasekharporeddy:testbranch2
PiperOrigin-RevId: 670380496
2024-09-02 20:26:07 -07:00
Sharad Vikram
443780e208 [Mosaic TPU] Add support for semaphore operands (inputs and outputs)
This enables writing async kernels for collectives or prefetching.

PiperOrigin-RevId: 670366575
2024-09-02 19:30:09 -07:00
jax authors
be087cf155 Merge pull request #23361 from froystig:optimizers
PiperOrigin-RevId: 670331545
2024-09-02 17:16:41 -07:00
jax authors
827d7ddfe9 Merge pull request #23386 from superbobry:maint-3
PiperOrigin-RevId: 670331365
2024-09-02 17:15:39 -07:00
Matthew Johnson
f2bef6bb5c tweak shmap implementation to work better with leak checker 2024-09-02 23:57:31 +00:00