Sharad Vikram
9d3762bd47
[Pallas] Add design note for async ops on TPU
2024-09-17 12:45:29 -07:00
Jevin Jiang
d27fce6981
[Pallas TPU] Fix dtype_bitwidth for int in util.
...
PiperOrigin-RevId: 675357560
2024-09-16 18:00:13 -07:00
Peter Hawkins
940860625e
Remove code that existed to support jaxlib < 0.4.32.
...
New minimum versions:
* jaxlib 0.4.32
* xla_extension_version 283
* mlir_api_version 57
PiperOrigin-RevId: 675291231
2024-09-16 14:30:00 -07:00
jax authors
df385b6ad3
Merge pull request #23652 from flferretti:feature/scalar_first_scipy_spatial
...
PiperOrigin-RevId: 675286624
2024-09-16 14:17:19 -07:00
Peter Hawkins
29163fcefd
Update XLA dependency to use revision
...
90be6e3a11
.
PiperOrigin-RevId: 675267262
2024-09-16 13:22:10 -07:00
jax authors
fc5e7b150d
Merge pull request #23672 from hawkinsp:xla
...
PiperOrigin-RevId: 675265483
2024-09-16 13:16:48 -07:00
Peter Hawkins
4fac852a90
Remove XLA tanh fix cherry-pick, to avoid CI breakages when the XLA commit is
...
bumped.
2024-09-16 20:05:25 +00:00
Filippo Luca Ferretti
2ff26ff3e0
Add scalar_first
argument to jax.scipy.spatial.transform.Rotation.as_quat
2024-09-16 21:57:55 +02:00
jax authors
8c84e1637e
Merge pull request #23670 from hawkinsp:postrelease
...
PiperOrigin-RevId: 675252501
2024-09-16 12:39:30 -07:00
Vadym Matsishevskyi
8804be0229
Add Python 3.130rc2 support to the build.
...
This PR depends on https://github.com/openxla/xla/pull/17169 . The change does not fail existing builds, but to be able to use python 3.13 functionality in jax the corresponding XLA pr needs to land first and get integrated with JAX (happens automatically).
PiperOrigin-RevId: 675243989
2024-09-16 12:14:32 -07:00
jax authors
90f532a9ac
Merge pull request #23556 from selamw1:docstring_frombuffer
...
PiperOrigin-RevId: 675240113
2024-09-16 12:03:53 -07:00
Yash Katariya
8ab66c8103
Fix the TPU and GPU nightly install instructions.
...
PiperOrigin-RevId: 675233702
2024-09-16 11:46:58 -07:00
Peter Hawkins
ae0e403c60
Merge release/0.4.33 into main and update version numbers.
2024-09-16 18:46:24 +00:00
jax authors
c60e9f08b6
Merge pull request #23665 from jakevdp:cleanup-string
...
PiperOrigin-RevId: 675229668
2024-09-16 11:37:15 -07:00
jax authors
543e802742
Merge pull request #23655 from rajasekharporeddy:testbranch3
...
PiperOrigin-RevId: 675221097
2024-09-16 11:16:24 -07:00
selamw1
7dde9b2909
frombuffer_docstring_added
...
description_changed_examp_added
doc_byte_fixed
discription_modified
2024-09-16 11:11:52 -07:00
Jake VanderPlas
321b4fbfbf
Remove unused string global
2024-09-16 11:03:11 -07:00
rajasekharporeddy
0942458d71
Improve doc for jnp.logaddexp2
2024-09-16 23:16:50 +05:30
Sergei Lebedev
8c39d0373a
Added a new primitive for copying GMEM<->SMEM in Pallas Mosaic GPU kernels
...
The copy is async and needs to be awaited via `plgpu.wait_inflight(...)` for
SMEM->GMEM copies and via `plgpu.wait(barrier)` for GMEM->SMEM copies.
I decided to have distinct functions for SMEM->GMEM and GMEM->SMEM copies
and for the ways to await the result, because the underlying Mosaic GPU
APIs (and PTX ISA) *are* in fact very different.
PiperOrigin-RevId: 675155317
2024-09-16 08:18:46 -07:00
jax authors
8a867c12f0
Merge pull request #23654 from jakevdp:atleast-nd-docs
...
PiperOrigin-RevId: 675140904
2024-09-16 07:32:19 -07:00
jax authors
dfa4e2413c
Merge pull request #23643 from TomAugspurger:fix/doc-typos
...
PiperOrigin-RevId: 675133675
2024-09-16 07:07:26 -07:00
Peter Hawkins
80e1c94de6
Prepare for v0.4.33 release.
...
This release is branched off the v0.4.32 release, with two changes:
a) a fixed libtpu pin, and
b) a patch to revert an F64 tanh issue on CPU.
2024-09-16 13:30:35 +00:00
Jake VanderPlas
d5ceb78708
Better documentation for jnp.atleast_*d
2024-09-16 06:05:19 -07:00
jax authors
a8bd6eb4d1
Merge pull request #23623 from rajasekharporeddy:testbranch2
...
PiperOrigin-RevId: 675115862
2024-09-16 05:59:56 -07:00
jax authors
d119b2f23e
Merge pull request #23644 from enerrio:doc_typo
...
PiperOrigin-RevId: 675109955
2024-09-16 05:35:54 -07:00
jax authors
a8b996af70
Merge pull request #23642 from mattjj:tweak-error-logic
...
PiperOrigin-RevId: 674973797
2024-09-15 19:14:53 -07:00
Jevin Jiang
839ce9a11d
[Pallas TPU] Refactor ref indexers to transforms and support ref bitcast.
...
This cl refactors Pallas memref indexers to transforms which can support different ref transforms: indexing, bitcast (added in this cl), reshape (to be added) and others. Like indexer, user can apply multiple transforms to same memref, eg:
```
ref.bitcast(type1).at[slice1].bitcast(type2).bitcast(type3).at[slice2]...
```
Jaxpr Preview (apply multiple transforms to same ref):
```
{ lambda ; a:MemRef<None>{int32[16,256]} b:MemRef<None>{int32[8,128]}. let
c:i32[8,128] <- a[:8,:][bitcast(int16[16,256])][bitcast(float16[16,256])][:,:128][bitcast(int32[8,128])][:,:]
b[:,:] <- c
in () }
```
Tested:
* DMA with bitcasted ref
* Load from bitcasted ref
* Store to bitcasted ref
* Multiple transforms
* Interpret Mode for ref transforms (updated discharge rules)
PiperOrigin-RevId: 674961388
2024-09-15 17:53:29 -07:00
jax authors
6bfa53d8c3
Update XLA dependency to use revision
...
af733ec6fb
.
PiperOrigin-RevId: 674918791
2024-09-15 12:58:21 -07:00
rajasekharporeddy
d60371b5db
Improve docs for jax.numpy: power and pow
2024-09-15 10:33:05 +05:30
enerrio
b8d135aa05
fix small typos in docs
2024-09-14 13:53:19 -07:00
Tom Augspurger
fcc8c3759d
Fixed func ref in shared-computation
2024-09-14 15:22:12 -05:00
jax authors
45d448c143
Update XLA dependency to use revision
...
dedab4f8cf
.
PiperOrigin-RevId: 674697897
2024-09-14 12:51:20 -07:00
Matthew Johnson
02bb3d1c84
tweak error logic to save a comment :)
2024-09-14 17:26:23 +00:00
George Necula
ee6f098fa9
[pallas] Clean up forward-compatibility conditionals in Pallas lowering
...
In cl/657184114 (July 29th) I have made some changes in error reporting for invalid block shapes, but have left behind some conditionals to ensure forward compatibility. We are now out of the forward compatibility windows, and we clean up those conditionals.
PiperOrigin-RevId: 674603915
2024-09-14 02:32:16 -07:00
jax authors
0daca46464
Update XLA dependency to use revision
...
32ebd694c4
.
PiperOrigin-RevId: 674404604
2024-09-13 13:01:29 -07:00
jax authors
28b5dee032
Disable flaky tsan tests temporarily.
...
PiperOrigin-RevId: 674338720
2024-09-13 10:03:24 -07:00
Kanglan Tang
5b8d5ce342
Fix some layout test failures on gpu backend
...
PiperOrigin-RevId: 674336502
2024-09-13 09:57:32 -07:00
Sergei Lebedev
83bccdd289
sharding
and weak_type
parameters of ShapeDtypeStruct
are now keyword-only
...
We decided not to go through a deprecation cycle for this change, because
in the vast majority of cases internally these parameters are bound via a
keyword argument anyway.
PiperOrigin-RevId: 674324964
2024-09-13 09:24:38 -07:00
Sergei Lebedev
b886bd7300
Removed the named_shape
argument from jex.core.ShapedArray
and jax.ShapeDtypeStruct
...
It is unused and was only kept around to avoid breaking internal users.
PiperOrigin-RevId: 674310795
2024-09-13 08:38:15 -07:00
Sergei Lebedev
40040e3f69
Added a new approx_math
flag to Mosaic GPU params in Pallas
...
The flag allows to control the precision of some operations, e.g. `exp`.
PiperOrigin-RevId: 674305430
2024-09-13 08:21:07 -07:00
Sergei Lebedev
8fa0e925dd
Added a docstring to dce_jaxpr
...
PiperOrigin-RevId: 674304558
2024-09-13 08:17:49 -07:00
Sergei Lebedev
db7484f392
Do a single mbarrier.arrive.expect_tx per fetch in Pallas Mosaic GPU
...
PiperOrigin-RevId: 674260767
2024-09-13 05:38:39 -07:00
Sergei Lebedev
427a490d2b
Ported a few changes to FragmentArray by cperivol@
...
* It now supports unary negation
* and pointwise operations between scalars and FragmentedArrays
PiperOrigin-RevId: 674244294
2024-09-13 04:32:37 -07:00
Sergei Lebedev
8159d3352c
Updated :gpu_test configuration
...
PiperOrigin-RevId: 674242448
2024-09-13 04:24:09 -07:00
jax authors
33de426c55
Merge pull request #23553 from selamw1:docstring_fromstring
...
PiperOrigin-RevId: 674208707
2024-09-13 02:16:50 -07:00
jax authors
9789056061
Fix a small typo for the condition of scipy.entr.
...
PiperOrigin-RevId: 674205855
2024-09-13 02:06:40 -07:00
jax authors
cc1e63daf5
Merge pull request #23618 from gnecula:export_doc1
...
PiperOrigin-RevId: 674204478
2024-09-13 02:02:55 -07:00
Sergei Lebedev
e2d7ef2a49
Pallas Mosaic GPU now supports scratch buffers in SMEM
...
PiperOrigin-RevId: 674173250
2024-09-13 00:09:57 -07:00
George Necula
67980d6af4
[export] Improve the forward compatibility documentation
...
Update the documentation to use the `LoweringRuleContext.is_forward_compat`
helper function.
2024-09-13 08:38:49 +03:00
Parker Schuh
16699952aa
ParsedPartitionSpec needs to check that it is the proper instance type
...
before comparing for equality or it will throw an exception in the later code.
PiperOrigin-RevId: 674106064
2024-09-12 19:50:47 -07:00