jax authors
be1f4ba380
Merge pull request #21905 from rajasekharporeddy:doc_typos
...
PiperOrigin-RevId: 644068464
2024-06-17 11:04:28 -07:00
Kyle Lucke
ebdafea9c8
Stop using xla/status.h, xla:status, and xla::Status now that xla::Status is just an alias for an absl::Status
...
PiperOrigin-RevId: 644063768
2024-06-17 10:51:55 -07:00
jax authors
21e2319e62
Merge pull request #21911 from gnecula:poly_solve_doc
...
PiperOrigin-RevId: 644062799
2024-06-17 10:51:41 -07:00
jax authors
f86cd6de56
Rewrite vector.multi_dim_reduction
with bf16 source/accumulator/output into
...
a multi_dim_reduction with f32 source/accumulator/output, where the source
and accumulator are extended and the result is truncated. This addressed 'only
32-bit reductions supported' error.
PiperOrigin-RevId: 644062786
2024-06-17 10:51:24 -07:00
jax authors
039b7c12e1
Merge pull request #21904 from tilakrayal:patch-2
...
PiperOrigin-RevId: 644061552
2024-06-17 10:48:05 -07:00
jax authors
2f22d3abfd
Merge pull request #21901 from jjyyxx:patch-1
...
PiperOrigin-RevId: 644061544
2024-06-17 10:44:50 -07:00
Sergei Lebedev
01f182e772
Use `unitialized_value
` for allocating outputs for interpreted Pallas kernels
...
PiperOrigin-RevId: 644057616
2024-06-17 10:34:38 -07:00
Junwhan Ahn
cec796f5dc
Batch pxla.shard_args
calls triggered by jax.device_put
...
With this change, one `jax.device_put` call now corresponds to one `device_put_p.bind()` instead of one per array. Immediately, this improves the performance of `jax.device_put(...)` with a large pytree by amortizing the calls to `pxla.shard_args`. Also, backends that implement efficient batch transfers (https://github.com/tensorflow/tensorflow/pull/69096 ) will batch device-to-device transfers across arrays in a pytree.
The api_benchmark indicates that this CL makes `device_put` with 10 to 1000 arrays ~30% faster, likely because it reduces the number of `device_put_p.bind()` calls.
PiperOrigin-RevId: 644051624
2024-06-17 10:17:25 -07:00
George Necula
b1a8c65883
[shape_poly] Add documentation for workaround with dimension parameters.
2024-06-17 20:14:20 +03:00
Sergei Lebedev
550862f8c1
Added some docs to `_hoist_consts_to_refs
`
...
I also restructured the implementation slightly, because most list allocations
were in fact unnecessary.
2024-06-17 15:33:05 +01:00
Adam Paszke
4ea73bf787
Use constant memory to pass in TMA descriptors to the kernel
...
To work around another buggy part of the PTX documentation. While PTX
explicitly says that TMA descriptors can be in global memory, the C++
programming guide heavily discurages this, because it can lead to
incorrrect results. Which is also what we've sometimes observed as
a cache coherency issue unless a TMA fence is explicitly inserted at the
beginning of the kernel.
Note that this approach has a big downside of making the kernel unsafe
for concurrent use. I don't think that XLA:GPU will ever dispatch it
concurrently so I didn't insert any extra synchronization for now, but
we should seriously consider it. My hope at the moment is that we'll
be able to start passing in TMA descs as kernel args soon (pending
upstreaming LLVM changes...) and we won't have to deal with this again.
For the programming guide, see: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#using-tma-to-transfer-multi-dimensional-arrays
PiperOrigin-RevId: 643972675
2024-06-17 05:31:26 -07:00
Sergei Lebedev
f67f2e06ce
Fixed a `ValueError
` when a Pallas GPU kernel closed over array constants
...
The fix idea is based on the investigation by @zhixuan-lin in #21557 .
PiperOrigin-RevId: 643965836
2024-06-17 05:05:01 -07:00
Sergei Lebedev
4913fff971
Rollback #21888 , because it breaks multiple internal users
...
Reverts 193591b5c0b90ce498015b2e3d48950615253380
PiperOrigin-RevId: 643965549
2024-06-17 05:01:04 -07:00
rajasekharporeddy
b93da3873b
Fix Typos
2024-06-17 13:55:46 +05:30
tilakrayal
b461846cc0
Fixing the naming conventions in lax_numpy.py
2024-06-17 11:56:52 +05:30
jax authors
595a620804
Update XLA dependency to use revision
...
081a8b35e8
.
PiperOrigin-RevId: 643808691
2024-06-16 13:09:08 -07:00
jax authors
ae0127d696
Merge pull request #21897 from matsen:patch-1
...
PiperOrigin-RevId: 643735130
2024-06-16 02:01:42 -07:00
Yuxuan Jiang
cd23b2c82c
Fix CUDNN requirement inconsistency in installation.md
2024-06-16 14:38:05 +08:00
jax authors
546a3a60eb
Update XLA dependency to use revision
...
340fbbf18f
.
PiperOrigin-RevId: 643651407
2024-06-15 13:52:48 -07:00
jax authors
993fa6108d
Merge pull request #21880 from gnecula:doc_poly
...
PiperOrigin-RevId: 643631386
2024-06-15 11:03:18 -07:00
George Necula
b58ff2ba20
[shape_poly] Add documentation for shape polymorphism
...
This involved writing some new content and also moving and adapting
the documentation that existed as part of the jax2tf
README file:
https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion
2024-06-15 18:20:54 +03:00
Erick Matsen
b585eceeb0
Fixing installation docs: needed quotes; markdown fix
2024-06-15 05:12:21 -07:00
jax authors
5019167106
Further reduce the run time of pmap_test
...
PiperOrigin-RevId: 643548103
2024-06-14 23:26:14 -07:00
Benjamin Chetioui
4e748e1fda
Bump up tolerance for tests using _testQdwh
.
...
PiperOrigin-RevId: 643547107
2024-06-14 23:18:20 -07:00
jax authors
193591b5c0
Merge pull request #21888 from jakevdp:lax-mul-bool
...
PiperOrigin-RevId: 643460399
2024-06-14 14:53:49 -07:00
Jake VanderPlas
4f7cd03893
lax.mul: accept boolean inputs
2024-06-14 13:47:11 -07:00
jax authors
895b490689
Update XLA dependency to use revision
...
04deefca53
.
PiperOrigin-RevId: 643433426
2024-06-14 13:09:54 -07:00
jax authors
99775043f0
Merge pull request #21890 from hawkinsp:nightly
...
PiperOrigin-RevId: 643429331
2024-06-14 12:54:18 -07:00
Peter Hawkins
d5844b7cd3
Update nightly installation instructions to reflect new jax/jaxlib dependency structure.
2024-06-14 15:48:31 -04:00
jax authors
e15e8b81ff
Merge pull request #21856 from jakevdp:dtype-array
...
PiperOrigin-RevId: 643361223
2024-06-14 08:59:35 -07:00
Yash Katariya
8f315c212a
Merge all nightly wheels into jax_nightly_releases.html
since we have dropped monolithic jaxlib CUDA wheels
...
PiperOrigin-RevId: 643331639
2024-06-14 07:01:44 -07:00
Jake VanderPlas
9d5932a190
Deprecate passing of arrays in place of dtypes.
2024-06-14 05:40:04 -07:00
Adam Paszke
56f6e74f45
Stop using deprecated device= argument to jax.jit in effects tests
...
PiperOrigin-RevId: 643294331
2024-06-14 03:20:07 -07:00
jax authors
c2b47e3eeb
Merge pull request #21879 from gnecula:exp_fix_deprecated
...
PiperOrigin-RevId: 643269932
2024-06-14 01:25:42 -07:00
George Necula
3b5b20c9cb
[export] Fix a couple of missed deprecated fields.
...
This fixes a couple of uses in jax2tf of deprecated
Exported fields.
2024-06-14 08:45:30 +03:00
Sharad Vikram
e12656002f
[Pallas] Don't actually vmap if we're vmapping over axis size 1
...
PiperOrigin-RevId: 643209848
2024-06-13 20:24:43 -07:00
Jake VanderPlas
a92fa547a0
Re-land https://github.com/google/jax/pull/21847
...
Reverts 0bcc81ceb33e3065110e3dd56ca215dbb62f0a7b
PiperOrigin-RevId: 643202512
2024-06-13 19:53:53 -07:00
jax authors
06ec7d1ad5
Reduce the matrix size in testPmapMapVmapCombinations
to reduce the test run time.
...
PiperOrigin-RevId: 643166085
2024-06-13 17:02:03 -07:00
jax authors
b907242f3e
Merge pull request #21871 from jakevdp:fix-warnings
...
PiperOrigin-RevId: 643156607
2024-06-13 16:26:09 -07:00
Jake VanderPlas
33465274da
fix some additional warnings related to #21834
2024-06-13 16:06:14 -07:00
jax authors
8f5f8df112
Merge pull request #21863 from jakevdp:dep-tracer-hash
...
PiperOrigin-RevId: 643147305
2024-06-13 15:52:36 -07:00
jax authors
c839b268d2
Get rid of the is_hermitian argument for lax.qdwh. If it was known that H was also positive semi-definite, the polar decomposition would be I*H. But for indefinite H, the QDWH algorithm does not differ from the general case for Hermitian inputs.
...
PiperOrigin-RevId: 643141687
2024-06-13 15:33:49 -07:00
jax authors
c36bb59c4d
Merge pull request #21864 from superbobry:pallas
...
PiperOrigin-RevId: 643132610
2024-06-13 15:03:30 -07:00
jax authors
0bcc81ceb3
Reverts 5aedafc214cf930f5b196b1eb130fd7ec866bc5e
...
PiperOrigin-RevId: 643131144
2024-06-13 14:58:54 -07:00
Yash Katariya
4ef33fa90e
Add trace
to stages.Wrapped
and add docs for it.
...
PiperOrigin-RevId: 643128426
2024-06-13 14:49:54 -07:00
jax authors
fa5a162c5e
Merge pull request #20728 from rajasekharporeddy:test_branch2
...
PiperOrigin-RevId: 643123794
2024-06-13 14:33:54 -07:00
jax authors
46899cb509
Merge pull request #21865 from hawkinsp:docs
...
PiperOrigin-RevId: 643121720
2024-06-13 14:27:42 -07:00
jax authors
e25528d026
Update XLA dependency to use revision
...
9b3d8c7505
.
PiperOrigin-RevId: 643116343
2024-06-13 14:10:02 -07:00
Peter Hawkins
c71298c2ee
Temporarily reinstate the jax[cpu] in the install instructions.
...
This is still needed until the next release.
2024-06-13 17:00:20 -04:00
Sergei Lebedev
2466ae3e93
Added docstrings to pl.num_programs() and pl.program_id()
2024-06-13 21:57:52 +01:00