Justin Fu
2d74c6aa05
Add TritonCompilerParams for specifying compiler arguments instead of a dict.
...
PiperOrigin-RevId: 671081069
2024-09-04 13:32:25 -07:00
Sergei Lebedev
a8a55e0f2e
Added pl.CompilerParams subclass for Mosaic GPU
...
PiperOrigin-RevId: 671066741
2024-09-04 12:48:34 -07:00
Vladimir Belitskiy
3672b633c3
Fix a deprecation warning for NumPy array conversion.
...
To address https://github.com/google/jax/actions/runs/10654663500/job/29531268089#step:6:656
```
E DeprecationWarning: NumPy will stop allowing conversion of out-of-bound Python integers to integer arrays. The conversion of 536870912 to int16 will fail in the future.
E For the old behavior, usually:
E np.array(value).astype(dtype)
E will give the desired result (the cast overflows).
```
PiperOrigin-RevId: 671064730
2024-09-04 12:41:28 -07:00
Yash Katariya
bf66e816dd
Split physical axes by default when device kind is TPU v5 lite
to allow for mesh shapes (2, 2) when there are 8 v5e devices on a 4x2 topology.
...
PiperOrigin-RevId: 671047455
2024-09-04 11:49:17 -07:00
jax authors
8278a72b88
Merge pull request #23430 from froystig:optimizers
...
PiperOrigin-RevId: 671030783
2024-09-04 11:03:42 -07:00
jax authors
6929211a20
Merge pull request #23434 from rajasekharporeddy:testbranch2
...
PiperOrigin-RevId: 671027735
2024-09-04 10:56:48 -07:00
jax authors
2d023db8a2
Merge pull request #23433 from jakevdp:rng-impl-repr
...
PiperOrigin-RevId: 671027668
2024-09-04 10:56:31 -07:00
jax authors
72e74d1ef6
Merge pull request #23364 from rajasekharporeddy:testbranch1
...
PiperOrigin-RevId: 671027520
2024-09-04 10:55:24 -07:00
Jake VanderPlas
d6394c0795
random.key_impl: improve repr of output
2024-09-04 10:10:31 -07:00
jax authors
f90558d317
Merge pull request #23248 from ROCm:rocm-ubuntu-images
...
PiperOrigin-RevId: 670983320
2024-09-04 08:39:49 -07:00
rajasekharporeddy
10893033b9
Remove unused docstring addition: _PRECISION_DOC
2024-09-04 19:57:51 +05:30
rajasekharporeddy
cb45fb426a
Better docs for jax.numpy: log and log1p
2024-09-04 19:22:58 +05:30
jax authors
22be4eafca
Merge pull request #23398 from damianoamatruda:fix-pytype-array
...
PiperOrigin-RevId: 670937582
2024-09-04 05:45:16 -07:00
Roy Frostig
8310a6ab1b
update example optimizers library docstring
...
* JAXopt is being merged into Optax, so point only to Optax
* Update Optax's github repository URL
2024-09-03 23:40:47 -07:00
Jevin Jiang
c1d3c2db9f
[Mosaic TPU] Fix mosaic alignment check in concatenate rule.
...
PiperOrigin-RevId: 670837792
2024-09-03 22:57:27 -07:00
jax authors
ebc6c18152
Merge pull request #23417 from pkgoogle:better_true_divide_doc
...
PiperOrigin-RevId: 670754635
2024-09-03 17:02:35 -07:00
Yash Katariya
e1b497078e
Rename jtu.create_global_mesh
to jtu.create_mesh
and use jax.make_mesh
inside jtu.create_mesh
to get maximum test coverage of the new API.
...
PiperOrigin-RevId: 670744047
2024-09-03 16:23:07 -07:00
jax authors
5c0ee1a3e9
Merge pull request #23422 from jakevdp:make-mesh-docs
...
PiperOrigin-RevId: 670742011
2024-09-03 16:16:04 -07:00
Jake VanderPlas
7569dd5438
Update sharded-computation doc to use make_mesh()
2024-09-03 16:04:23 -07:00
Piseth Ky
78212ae39e
better true_divide and divide docs
...
doc wording update
2024-09-03 16:03:55 -07:00
Sergei Lebedev
1289640f09
Deprecated calling `jax.dlpack.from_dlpack
` with a DLPack tensor
...
PiperOrigin-RevId: 670723176
2024-09-03 15:16:02 -07:00
jax authors
d0d7493aae
Update XLA dependency to use revision
...
950df46440
.
PiperOrigin-RevId: 670719765
2024-09-03 15:04:33 -07:00
Yash Katariya
252caebce3
Create jax.make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str], devices: Sequence[jax.Device] | None = None)
API to make it easier to create a mesh and reduce a ton of boilerplate.
...
`jax.make_mesh` is the stable API endpoint of `mesh_utils` but without all the extra options. If you want those, you can still use the experimental endpoint in `mesh_utils`.
PiperOrigin-RevId: 670707995
2024-09-03 14:32:03 -07:00
Sergei Lebedev
9b31b73d9d
Added basic pipelining to Pallas on Mosaic GPU
...
The implementation only allows at most one sequential axis at the moment.
PiperOrigin-RevId: 670687671
2024-09-03 13:37:46 -07:00
jax authors
cda2408e14
Merge pull request #23394 from gspschmid:gschmid/ffi-support-token
...
PiperOrigin-RevId: 670652970
2024-09-03 12:06:07 -07:00
jax authors
40bdbcdbc2
Merge pull request #23407 from hawkinsp:hardware
...
PiperOrigin-RevId: 670645993
2024-09-03 11:49:42 -07:00
Georg Stefan Schmid
24bb8ae443
[ffi] Add support for token inputs and outputs
2024-09-03 18:28:34 +00:00
Sergei Lebedev
9030aec097
Added a new Pallas Triton primitive -- `plgpu.debug_barrier
`
...
Closes #23400 .
PiperOrigin-RevId: 670636723
2024-09-03 11:27:12 -07:00
Peter Hawkins
f92d4e3e3d
Add TPU v6e to the list of known TPUs.
...
JAX will warn if it sees a device ID on this list but the runtime doesn't find one.
2024-09-03 14:20:35 -04:00
jax authors
ff702cb249
Merge pull request #23396 from hawkinsp:ci
...
PiperOrigin-RevId: 670607474
2024-09-03 10:16:49 -07:00
jax authors
cd55d0c91b
Merge pull request #23397 from jakevdp:dep-round
...
PiperOrigin-RevId: 670604941
2024-09-03 10:10:43 -07:00
Damiano Amatruda
87350b7128
Fix pytype errors and args for jax.Array methods
2024-09-03 17:06:45 +00:00
Jake VanderPlas
fd897745d3
Partial rollback of https://github.com/google/jax/pull/23353 as discussed in https://github.com/google/jax/pull/23353#issuecomment-2326604708
...
Reverts eed273c106af699efefc726eea1ff2b0f548f669
PiperOrigin-RevId: 670596159
2024-09-03 09:49:22 -07:00
jax authors
2a42bbe48a
Merge pull request #22843 from jeertmans:scipy-special-fresnel
...
PiperOrigin-RevId: 670543225
2024-09-03 07:07:19 -07:00
Jake VanderPlas
f2ffe7f8f2
Deprecate jax.numpy.round_
...
NumPy removed np.round in version 2.0; jax.numpy.round is drop-in
replacement.
2024-09-03 06:52:07 -07:00
Peter Hawkins
bc415f9153
Relax test tolerances to fix CI failures on Mac ARM.
2024-09-03 09:45:28 -04:00
Sergei Lebedev
ccabd21084
Fixed rules where `sliding_window_length
` was not forwarded
...
This is follow up to #23284 .
PiperOrigin-RevId: 670531634
2024-09-03 06:24:01 -07:00
Chris Jones
7b161fb76c
[jax:pallas] Use 64-bit indexing when necessary when lowering to Triton.
...
PiperOrigin-RevId: 670530776
2024-09-03 06:20:39 -07:00
Adam Paszke
4c3111bf26
[Mosaic GPU] Unbreak tests
...
I mistakenly checked for `amount + 1` instead of `amount * 2`. It initially
seemed right because both expressions evalute to 2 for 1 :)
PiperOrigin-RevId: 670527107
2024-09-03 06:07:54 -07:00
jax authors
eed273c106
Merge pull request #23353 from jakevdp:lax-deps
...
PiperOrigin-RevId: 670523237
2024-09-03 05:59:26 -07:00
Chris Jones
a1fd582ad6
[jax:pallas] Simplify pointer offset calculation in Triton lowering.
...
PiperOrigin-RevId: 670499398
2024-09-03 04:25:06 -07:00
Jérome Eertmans
f9cb95ca08
feat(lib): add real-valued implementation of jax.scipy.special.fresnel
...
Add implementation, documentation, and tests, for both single-precision and double-precision floating-point arithmetic.
2024-09-03 09:50:19 +02:00
Fabian Pedregosa
530ed026b8
FIX typo on jax.numpy.where docstring
...
this was preventing the link to be correctly rendered in the webpage https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html
PiperOrigin-RevId: 670385290
2024-09-02 20:50:00 -07:00
jax authors
225a028f1d
Merge pull request #23371 from rajasekharporeddy:testbranch3
...
PiperOrigin-RevId: 670380554
2024-09-02 20:27:24 -07:00
jax authors
281bfcdc62
Merge pull request #23387 from mattjj:shmap-leak-checker
...
PiperOrigin-RevId: 670380518
2024-09-02 20:27:08 -07:00
jax authors
826e661347
Merge pull request #23365 from rajasekharporeddy:testbranch2
...
PiperOrigin-RevId: 670380496
2024-09-02 20:26:07 -07:00
Sharad Vikram
443780e208
[Mosaic TPU] Add support for semaphore operands (inputs and outputs)
...
This enables writing async kernels for collectives or prefetching.
PiperOrigin-RevId: 670366575
2024-09-02 19:30:09 -07:00
jax authors
be087cf155
Merge pull request #23361 from froystig:optimizers
...
PiperOrigin-RevId: 670331545
2024-09-02 17:16:41 -07:00
jax authors
827d7ddfe9
Merge pull request #23386 from superbobry:maint-3
...
PiperOrigin-RevId: 670331365
2024-09-02 17:15:39 -07:00
Matthew Johnson
f2bef6bb5c
tweak shmap implementation to work better with leak checker
2024-09-02 23:57:31 +00:00