23004 Commits

Author SHA1 Message Date
Sharad Vikram
9d3762bd47 [Pallas] Add design note for async ops on TPU 2024-09-17 12:45:29 -07:00
Jevin Jiang
d27fce6981 [Pallas TPU] Fix dtype_bitwidth for int in util.
PiperOrigin-RevId: 675357560
2024-09-16 18:00:13 -07:00
Peter Hawkins
940860625e Remove code that existed to support jaxlib < 0.4.32.
New minimum versions:
* jaxlib 0.4.32
* xla_extension_version 283
* mlir_api_version 57

PiperOrigin-RevId: 675291231
2024-09-16 14:30:00 -07:00
jax authors
df385b6ad3 Merge pull request #23652 from flferretti:feature/scalar_first_scipy_spatial
PiperOrigin-RevId: 675286624
2024-09-16 14:17:19 -07:00
Peter Hawkins
29163fcefd Update XLA dependency to use revision
90be6e3a11.

PiperOrigin-RevId: 675267262
2024-09-16 13:22:10 -07:00
jax authors
fc5e7b150d Merge pull request #23672 from hawkinsp:xla
PiperOrigin-RevId: 675265483
2024-09-16 13:16:48 -07:00
Peter Hawkins
4fac852a90 Remove XLA tanh fix cherry-pick, to avoid CI breakages when the XLA commit is
bumped.
2024-09-16 20:05:25 +00:00
Filippo Luca Ferretti
2ff26ff3e0 Add scalar_first argument to jax.scipy.spatial.transform.Rotation.as_quat 2024-09-16 21:57:55 +02:00
jax authors
8c84e1637e Merge pull request #23670 from hawkinsp:postrelease
PiperOrigin-RevId: 675252501
2024-09-16 12:39:30 -07:00
Vadym Matsishevskyi
8804be0229 Add Python 3.130rc2 support to the build.
This PR depends on https://github.com/openxla/xla/pull/17169. The change does not fail existing builds, but to be able to use python 3.13 functionality in jax the corresponding XLA pr needs to land first and get integrated with JAX (happens automatically).

PiperOrigin-RevId: 675243989
2024-09-16 12:14:32 -07:00
jax authors
90f532a9ac Merge pull request #23556 from selamw1:docstring_frombuffer
PiperOrigin-RevId: 675240113
2024-09-16 12:03:53 -07:00
Yash Katariya
8ab66c8103 Fix the TPU and GPU nightly install instructions.
PiperOrigin-RevId: 675233702
2024-09-16 11:46:58 -07:00
Peter Hawkins
ae0e403c60 Merge release/0.4.33 into main and update version numbers. 2024-09-16 18:46:24 +00:00
jax authors
c60e9f08b6 Merge pull request #23665 from jakevdp:cleanup-string
PiperOrigin-RevId: 675229668
2024-09-16 11:37:15 -07:00
jax authors
543e802742 Merge pull request #23655 from rajasekharporeddy:testbranch3
PiperOrigin-RevId: 675221097
2024-09-16 11:16:24 -07:00
selamw1
7dde9b2909 frombuffer_docstring_added
description_changed_examp_added

doc_byte_fixed

discription_modified
2024-09-16 11:11:52 -07:00
Jake VanderPlas
321b4fbfbf Remove unused string global 2024-09-16 11:03:11 -07:00
rajasekharporeddy
0942458d71 Improve doc for jnp.logaddexp2 2024-09-16 23:16:50 +05:30
Sergei Lebedev
8c39d0373a Added a new primitive for copying GMEM<->SMEM in Pallas Mosaic GPU kernels
The copy is async and needs to be awaited via `plgpu.wait_inflight(...)` for
SMEM->GMEM copies and via `plgpu.wait(barrier)` for GMEM->SMEM copies.

I decided to have distinct functions for SMEM->GMEM and GMEM->SMEM copies
and for the ways to await the result, because the underlying Mosaic GPU
APIs (and PTX ISA) *are* in fact very different.

PiperOrigin-RevId: 675155317
2024-09-16 08:18:46 -07:00
jax authors
8a867c12f0 Merge pull request #23654 from jakevdp:atleast-nd-docs
PiperOrigin-RevId: 675140904
2024-09-16 07:32:19 -07:00
jax authors
dfa4e2413c Merge pull request #23643 from TomAugspurger:fix/doc-typos
PiperOrigin-RevId: 675133675
2024-09-16 07:07:26 -07:00
Peter Hawkins
80e1c94de6 Prepare for v0.4.33 release.
This release is branched off the v0.4.32 release, with two changes:
a) a fixed libtpu pin, and
b) a patch to revert an F64 tanh issue on CPU.
2024-09-16 13:30:35 +00:00
Jake VanderPlas
d5ceb78708 Better documentation for jnp.atleast_*d 2024-09-16 06:05:19 -07:00
jax authors
a8bd6eb4d1 Merge pull request #23623 from rajasekharporeddy:testbranch2
PiperOrigin-RevId: 675115862
2024-09-16 05:59:56 -07:00
jax authors
d119b2f23e Merge pull request #23644 from enerrio:doc_typo
PiperOrigin-RevId: 675109955
2024-09-16 05:35:54 -07:00
jax authors
a8b996af70 Merge pull request #23642 from mattjj:tweak-error-logic
PiperOrigin-RevId: 674973797
2024-09-15 19:14:53 -07:00
Jevin Jiang
839ce9a11d [Pallas TPU] Refactor ref indexers to transforms and support ref bitcast.
This cl refactors Pallas memref indexers to transforms which can support different ref transforms: indexing, bitcast (added in this cl), reshape (to be added) and others. Like indexer, user can apply multiple transforms to same memref, eg:
```
ref.bitcast(type1).at[slice1].bitcast(type2).bitcast(type3).at[slice2]...
```

Jaxpr Preview (apply multiple transforms to same ref):
```
{ lambda ; a:MemRef<None>{int32[16,256]} b:MemRef<None>{int32[8,128]}. let
    c:i32[8,128] <- a[:8,:][bitcast(int16[16,256])][bitcast(float16[16,256])][:,:128][bitcast(int32[8,128])][:,:]
    b[:,:] <- c
  in () }
```

Tested:
* DMA with bitcasted ref
* Load from bitcasted ref
* Store to bitcasted ref
* Multiple transforms
* Interpret Mode for ref transforms (updated discharge rules)
PiperOrigin-RevId: 674961388
2024-09-15 17:53:29 -07:00
jax authors
6bfa53d8c3 Update XLA dependency to use revision
af733ec6fb.

PiperOrigin-RevId: 674918791
2024-09-15 12:58:21 -07:00
rajasekharporeddy
d60371b5db Improve docs for jax.numpy: power and pow 2024-09-15 10:33:05 +05:30
enerrio
b8d135aa05 fix small typos in docs 2024-09-14 13:53:19 -07:00
Tom Augspurger
fcc8c3759d Fixed func ref in shared-computation 2024-09-14 15:22:12 -05:00
jax authors
45d448c143 Update XLA dependency to use revision
dedab4f8cf.

PiperOrigin-RevId: 674697897
2024-09-14 12:51:20 -07:00
Matthew Johnson
02bb3d1c84 tweak error logic to save a comment :) 2024-09-14 17:26:23 +00:00
George Necula
ee6f098fa9 [pallas] Clean up forward-compatibility conditionals in Pallas lowering
In cl/657184114 (July 29th) I have made some changes in error reporting for invalid block shapes, but have left behind some conditionals to ensure forward compatibility. We are now out of the forward compatibility windows, and we clean up those conditionals.

PiperOrigin-RevId: 674603915
2024-09-14 02:32:16 -07:00
jax authors
0daca46464 Update XLA dependency to use revision
32ebd694c4.

PiperOrigin-RevId: 674404604
2024-09-13 13:01:29 -07:00
jax authors
28b5dee032 Disable flaky tsan tests temporarily.
PiperOrigin-RevId: 674338720
2024-09-13 10:03:24 -07:00
Kanglan Tang
5b8d5ce342 Fix some layout test failures on gpu backend
PiperOrigin-RevId: 674336502
2024-09-13 09:57:32 -07:00
Sergei Lebedev
83bccdd289 sharding and weak_type parameters of ShapeDtypeStruct are now keyword-only
We decided not to go through a deprecation cycle for this change, because
in the vast majority of cases internally these parameters are bound via a
keyword argument anyway.

PiperOrigin-RevId: 674324964
2024-09-13 09:24:38 -07:00
Sergei Lebedev
b886bd7300 Removed the named_shape argument from jex.core.ShapedArray and jax.ShapeDtypeStruct
It is unused and was only kept around to avoid breaking internal users.

PiperOrigin-RevId: 674310795
2024-09-13 08:38:15 -07:00
Sergei Lebedev
40040e3f69 Added a new approx_math flag to Mosaic GPU params in Pallas
The flag allows to control the precision of some operations, e.g. `exp`.

PiperOrigin-RevId: 674305430
2024-09-13 08:21:07 -07:00
Sergei Lebedev
8fa0e925dd Added a docstring to dce_jaxpr
PiperOrigin-RevId: 674304558
2024-09-13 08:17:49 -07:00
Sergei Lebedev
db7484f392 Do a single mbarrier.arrive.expect_tx per fetch in Pallas Mosaic GPU
PiperOrigin-RevId: 674260767
2024-09-13 05:38:39 -07:00
Sergei Lebedev
427a490d2b Ported a few changes to FragmentArray by cperivol@
* It now supports unary negation
* and pointwise operations between scalars and FragmentedArrays

PiperOrigin-RevId: 674244294
2024-09-13 04:32:37 -07:00
Sergei Lebedev
8159d3352c Updated :gpu_test configuration
PiperOrigin-RevId: 674242448
2024-09-13 04:24:09 -07:00
jax authors
33de426c55 Merge pull request #23553 from selamw1:docstring_fromstring
PiperOrigin-RevId: 674208707
2024-09-13 02:16:50 -07:00
jax authors
9789056061 Fix a small typo for the condition of scipy.entr.
PiperOrigin-RevId: 674205855
2024-09-13 02:06:40 -07:00
jax authors
cc1e63daf5 Merge pull request #23618 from gnecula:export_doc1
PiperOrigin-RevId: 674204478
2024-09-13 02:02:55 -07:00
Sergei Lebedev
e2d7ef2a49 Pallas Mosaic GPU now supports scratch buffers in SMEM
PiperOrigin-RevId: 674173250
2024-09-13 00:09:57 -07:00
George Necula
67980d6af4 [export] Improve the forward compatibility documentation
Update the documentation to use the `LoweringRuleContext.is_forward_compat`
helper function.
2024-09-13 08:38:49 +03:00
Parker Schuh
16699952aa ParsedPartitionSpec needs to check that it is the proper instance type
before comparing for equality or it will throw an exception in the later code.

PiperOrigin-RevId: 674106064
2024-09-12 19:50:47 -07:00