21466 Commits

Author SHA1 Message Date
jax authors
be1f4ba380 Merge pull request #21905 from rajasekharporeddy:doc_typos
PiperOrigin-RevId: 644068464
2024-06-17 11:04:28 -07:00
Kyle Lucke
ebdafea9c8 Stop using xla/status.h, xla:status, and xla::Status now that xla::Status is just an alias for an absl::Status
PiperOrigin-RevId: 644063768
2024-06-17 10:51:55 -07:00
jax authors
21e2319e62 Merge pull request #21911 from gnecula:poly_solve_doc
PiperOrigin-RevId: 644062799
2024-06-17 10:51:41 -07:00
jax authors
f86cd6de56 Rewrite vector.multi_dim_reduction with bf16 source/accumulator/output into
a multi_dim_reduction with f32 source/accumulator/output, where the source
and accumulator are extended and the result is truncated. This addressed 'only
32-bit reductions supported' error.

PiperOrigin-RevId: 644062786
2024-06-17 10:51:24 -07:00
jax authors
039b7c12e1 Merge pull request #21904 from tilakrayal:patch-2
PiperOrigin-RevId: 644061552
2024-06-17 10:48:05 -07:00
jax authors
2f22d3abfd Merge pull request #21901 from jjyyxx:patch-1
PiperOrigin-RevId: 644061544
2024-06-17 10:44:50 -07:00
Sergei Lebedev
01f182e772 Use `unitialized_value` for allocating outputs for interpreted Pallas kernels
PiperOrigin-RevId: 644057616
2024-06-17 10:34:38 -07:00
Junwhan Ahn
cec796f5dc Batch pxla.shard_args calls triggered by jax.device_put
With this change, one `jax.device_put` call now corresponds to one `device_put_p.bind()` instead of one per array. Immediately, this improves the performance of `jax.device_put(...)` with a large pytree by amortizing the calls to `pxla.shard_args`. Also, backends that implement efficient batch transfers (https://github.com/tensorflow/tensorflow/pull/69096) will batch device-to-device transfers across arrays in a pytree.

The api_benchmark indicates that this CL makes `device_put` with 10 to 1000 arrays ~30% faster, likely because it reduces the number of `device_put_p.bind()` calls.

PiperOrigin-RevId: 644051624
2024-06-17 10:17:25 -07:00
George Necula
b1a8c65883 [shape_poly] Add documentation for workaround with dimension parameters. 2024-06-17 20:14:20 +03:00
Sergei Lebedev
550862f8c1 Added some docs to `_hoist_consts_to_refs`
I also restructured the implementation slightly, because most list allocations
were in fact unnecessary.
2024-06-17 15:33:05 +01:00
Adam Paszke
4ea73bf787 Use constant memory to pass in TMA descriptors to the kernel
To work around another buggy part of the PTX documentation. While PTX
explicitly says that TMA descriptors can be in global memory, the C++
programming guide heavily discurages this, because it can lead to
incorrrect results. Which is also what we've sometimes observed as
a cache coherency issue unless a TMA fence is explicitly inserted at the
beginning of the kernel.

Note that this approach has a big downside of making the kernel unsafe
for concurrent use. I don't think that XLA:GPU will ever dispatch it
concurrently so I didn't insert any extra synchronization for now, but
we should seriously consider it. My hope at the moment is that we'll
be able to start passing in TMA descs as kernel args soon (pending
upstreaming LLVM changes...) and we won't have to deal with this again.

For the programming guide, see: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#using-tma-to-transfer-multi-dimensional-arrays

PiperOrigin-RevId: 643972675
2024-06-17 05:31:26 -07:00
Sergei Lebedev
f67f2e06ce Fixed a `ValueError` when a Pallas GPU kernel closed over array constants
The fix idea is based on the investigation by @zhixuan-lin in #21557.

PiperOrigin-RevId: 643965836
2024-06-17 05:05:01 -07:00
Sergei Lebedev
4913fff971 Rollback #21888, because it breaks multiple internal users
Reverts 193591b5c0b90ce498015b2e3d48950615253380

PiperOrigin-RevId: 643965549
2024-06-17 05:01:04 -07:00
rajasekharporeddy
b93da3873b Fix Typos 2024-06-17 13:55:46 +05:30
tilakrayal
b461846cc0
Fixing the naming conventions in lax_numpy.py 2024-06-17 11:56:52 +05:30
jax authors
595a620804 Update XLA dependency to use revision
081a8b35e8.

PiperOrigin-RevId: 643808691
2024-06-16 13:09:08 -07:00
jax authors
ae0127d696 Merge pull request #21897 from matsen:patch-1
PiperOrigin-RevId: 643735130
2024-06-16 02:01:42 -07:00
Yuxuan Jiang
cd23b2c82c
Fix CUDNN requirement inconsistency in installation.md 2024-06-16 14:38:05 +08:00
jax authors
546a3a60eb Update XLA dependency to use revision
340fbbf18f.

PiperOrigin-RevId: 643651407
2024-06-15 13:52:48 -07:00
jax authors
993fa6108d Merge pull request #21880 from gnecula:doc_poly
PiperOrigin-RevId: 643631386
2024-06-15 11:03:18 -07:00
George Necula
b58ff2ba20 [shape_poly] Add documentation for shape polymorphism
This involved writing some new content and also moving and adapting
the documentation that existed as part of the jax2tf
README file:

https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion
2024-06-15 18:20:54 +03:00
Erick Matsen
b585eceeb0
Fixing installation docs: needed quotes; markdown fix 2024-06-15 05:12:21 -07:00
jax authors
5019167106 Further reduce the run time of pmap_test
PiperOrigin-RevId: 643548103
2024-06-14 23:26:14 -07:00
Benjamin Chetioui
4e748e1fda Bump up tolerance for tests using _testQdwh.
PiperOrigin-RevId: 643547107
2024-06-14 23:18:20 -07:00
jax authors
193591b5c0 Merge pull request #21888 from jakevdp:lax-mul-bool
PiperOrigin-RevId: 643460399
2024-06-14 14:53:49 -07:00
Jake VanderPlas
4f7cd03893 lax.mul: accept boolean inputs 2024-06-14 13:47:11 -07:00
jax authors
895b490689 Update XLA dependency to use revision
04deefca53.

PiperOrigin-RevId: 643433426
2024-06-14 13:09:54 -07:00
jax authors
99775043f0 Merge pull request #21890 from hawkinsp:nightly
PiperOrigin-RevId: 643429331
2024-06-14 12:54:18 -07:00
Peter Hawkins
d5844b7cd3 Update nightly installation instructions to reflect new jax/jaxlib dependency structure. 2024-06-14 15:48:31 -04:00
jax authors
e15e8b81ff Merge pull request #21856 from jakevdp:dtype-array
PiperOrigin-RevId: 643361223
2024-06-14 08:59:35 -07:00
Yash Katariya
8f315c212a Merge all nightly wheels into jax_nightly_releases.html since we have dropped monolithic jaxlib CUDA wheels
PiperOrigin-RevId: 643331639
2024-06-14 07:01:44 -07:00
Jake VanderPlas
9d5932a190 Deprecate passing of arrays in place of dtypes. 2024-06-14 05:40:04 -07:00
Adam Paszke
56f6e74f45 Stop using deprecated device= argument to jax.jit in effects tests
PiperOrigin-RevId: 643294331
2024-06-14 03:20:07 -07:00
jax authors
c2b47e3eeb Merge pull request #21879 from gnecula:exp_fix_deprecated
PiperOrigin-RevId: 643269932
2024-06-14 01:25:42 -07:00
George Necula
3b5b20c9cb [export] Fix a couple of missed deprecated fields.
This fixes a couple of uses in jax2tf of deprecated
Exported fields.
2024-06-14 08:45:30 +03:00
Sharad Vikram
e12656002f [Pallas] Don't actually vmap if we're vmapping over axis size 1
PiperOrigin-RevId: 643209848
2024-06-13 20:24:43 -07:00
Jake VanderPlas
a92fa547a0 Re-land https://github.com/google/jax/pull/21847
Reverts 0bcc81ceb33e3065110e3dd56ca215dbb62f0a7b

PiperOrigin-RevId: 643202512
2024-06-13 19:53:53 -07:00
jax authors
06ec7d1ad5 Reduce the matrix size in testPmapMapVmapCombinations to reduce the test run time.
PiperOrigin-RevId: 643166085
2024-06-13 17:02:03 -07:00
jax authors
b907242f3e Merge pull request #21871 from jakevdp:fix-warnings
PiperOrigin-RevId: 643156607
2024-06-13 16:26:09 -07:00
Jake VanderPlas
33465274da fix some additional warnings related to #21834 2024-06-13 16:06:14 -07:00
jax authors
8f5f8df112 Merge pull request #21863 from jakevdp:dep-tracer-hash
PiperOrigin-RevId: 643147305
2024-06-13 15:52:36 -07:00
jax authors
c839b268d2 Get rid of the is_hermitian argument for lax.qdwh. If it was known that H was also positive semi-definite, the polar decomposition would be I*H. But for indefinite H, the QDWH algorithm does not differ from the general case for Hermitian inputs.
PiperOrigin-RevId: 643141687
2024-06-13 15:33:49 -07:00
jax authors
c36bb59c4d Merge pull request #21864 from superbobry:pallas
PiperOrigin-RevId: 643132610
2024-06-13 15:03:30 -07:00
jax authors
0bcc81ceb3 Reverts 5aedafc214cf930f5b196b1eb130fd7ec866bc5e
PiperOrigin-RevId: 643131144
2024-06-13 14:58:54 -07:00
Yash Katariya
4ef33fa90e Add trace to stages.Wrapped and add docs for it.
PiperOrigin-RevId: 643128426
2024-06-13 14:49:54 -07:00
jax authors
fa5a162c5e Merge pull request #20728 from rajasekharporeddy:test_branch2
PiperOrigin-RevId: 643123794
2024-06-13 14:33:54 -07:00
jax authors
46899cb509 Merge pull request #21865 from hawkinsp:docs
PiperOrigin-RevId: 643121720
2024-06-13 14:27:42 -07:00
jax authors
e25528d026 Update XLA dependency to use revision
9b3d8c7505.

PiperOrigin-RevId: 643116343
2024-06-13 14:10:02 -07:00
Peter Hawkins
c71298c2ee Temporarily reinstate the jax[cpu] in the install instructions.
This is still needed until the next release.
2024-06-13 17:00:20 -04:00
Sergei Lebedev
2466ae3e93 Added docstrings to pl.num_programs() and pl.program_id() 2024-06-13 21:57:52 +01:00