Ben West
02f6fcb9da
Add beta function
2023-11-05 15:37:38 -08:00
jax authors
1126945da8
Update XLA dependency to use revision
...
a58070090a
.
PiperOrigin-RevId: 579592900
2023-11-05 01:58:12 -08:00
Tamás Danyluk
bfbf9e1c33
[XLA:GPU] Consider Triton for all non-pure GEMM fusions
...
This is a big step toward enabling xla_gpu_triton_gemm_any by default.
It shows about 1.05x geomean speedup on internal benchmarks (comparable to xla_gpu_triton_gemm_any=true).
PiperOrigin-RevId: 579524573
2023-11-04 16:05:19 -07:00
jax authors
dda76733e8
Update XLA dependency to use revision
...
75bd8cc701
.
PiperOrigin-RevId: 579430578
2023-11-04 03:51:20 -07:00
jax authors
d73facd629
Merge pull request #18382 from jakevdp:prngkey-errors
...
PiperOrigin-RevId: 579365703
2023-11-03 19:46:27 -07:00
jax authors
28b512a457
Merge pull request #18137 from jakevdp:jep-numpy-scipy
...
PiperOrigin-RevId: 579365178
2023-11-03 19:37:12 -07:00
Jake VanderPlas
96d9f89415
[random] better errors for unsupported operations on prng keys
2023-11-03 19:23:18 -07:00
Jake VanderPlas
d623d04172
JEP 18137: Scope of JAX NumPy & SciPy Wrappers
2023-11-03 19:19:21 -07:00
Tomás Longeri
1c1dd7c8c7
[Mosaic] Expose C API for VectorLayout, VRegDataBounds
...
This is in preparation for Python bindings
PiperOrigin-RevId: 579355000
2023-11-03 18:24:16 -07:00
jax authors
953f4670d8
Merge pull request #18376 from jakevdp:doc-fix
...
PiperOrigin-RevId: 579279144
2023-11-03 13:02:09 -07:00
jax authors
e227536fd6
In api_test.py, wait for the result in test_double_donation.
...
PiperOrigin-RevId: 579267104
2023-11-03 12:23:55 -07:00
Jake VanderPlas
7d8c358fce
Fix wording in TFDS example
2023-11-03 11:31:48 -07:00
jax authors
808289d52a
Merge pull request #18373 from jakevdp:sharding-doc
...
PiperOrigin-RevId: 579224607
2023-11-03 10:23:44 -07:00
Jake VanderPlas
cd3ea05665
Ensure sharding-related array properties are documented
2023-11-03 09:56:33 -07:00
Peter Hawkins
011d49c518
Add a test for double donation.
...
The underlying issue was fixed some time ago.
Fixes https://github.com/google/jax/issues/9635
PiperOrigin-RevId: 579170638
2023-11-03 07:03:13 -07:00
jax authors
57385cb284
Update XLA dependency to use revision
...
049a3e6caf
.
PiperOrigin-RevId: 579135741
2023-11-03 04:04:27 -07:00
jax authors
db07f40233
Fall-back to original device/backend hashing if topology-desc is unavailable.
...
The original cache-key generation algorithm hashed devices and backend as
part of generating the key. The new algorithm relies on serialized
PjRtTopologyDescription instead. Not all backends support serialized
PjRtTopologyDescription. Fall back to the original device/backend hashing
if the needed backend does not support it.
Testing: unit testing + test workloads.
PiperOrigin-RevId: 579039803
2023-11-02 18:43:48 -07:00
jax authors
59192b0dd1
Replace gcc with clang compiler in presubmit and postsubmit CI Kokoro jobs.
...
PiperOrigin-RevId: 579032820
2023-11-02 18:01:36 -07:00
jax authors
d1a8a7876b
Merge pull request #18363 from skye:version
...
PiperOrigin-RevId: 579029941
2023-11-02 17:45:04 -07:00
Skye Wanderman-Milne
55e3072d2e
Update versions and CHANGELOG after jax 0.4.20 release
2023-11-02 16:30:56 -07:00
jax authors
62741d9744
Reverts 81ac67f38164b7626d733d081a87ff49b235b9d0
...
PiperOrigin-RevId: 579010408
2023-11-02 16:17:29 -07:00
Jieying Luo
c9db50cfd0
Enable python_callback_test for stream executor.
...
python_callback_test is supported for GPU stream executor. TPU stream executor was deprecated.
PiperOrigin-RevId: 578960299
2023-11-02 13:26:59 -07:00
Parker Schuh
c8b7c1b80b
Remove temporary flag for forcing arg tuplization of lowered functions.
...
PiperOrigin-RevId: 578910366
2023-11-02 10:53:16 -07:00
jax authors
1f6264896d
Merge pull request #18354 from jakevdp:dep-opaque
...
PiperOrigin-RevId: 578904025
2023-11-02 10:37:47 -07:00
jax authors
6d5eaa6ec3
Merge pull request #18295 from gnecula:lax_multi
...
PiperOrigin-RevId: 578892592
2023-11-02 10:07:34 -07:00
Jake VanderPlas
0111dcbda3
Finish deprecation of allow_opaque_dtype
2023-11-02 09:51:06 -07:00
jax authors
e22b8bbc0a
Merge pull request #18325 from skye:version
...
PiperOrigin-RevId: 578882609
jax-v0.4.20
jax-v0.4.20-rc
2023-11-02 09:37:54 -07:00
Skye Wanderman-Milne
6813819187
Update versions for jax + jaxlib 0.4.20 release
2023-11-02 09:34:16 -07:00
Etienne Pot
81ac67f381
Fix typing annotations for @jax.named_call
...
PiperOrigin-RevId: 578852649
2023-11-02 07:55:04 -07:00
George Necula
8feb413211
Add a lax.platform_dependent API for writing platform-dependent code.
...
In JAX the actual platform on which a computation is run is determined
very late, e.g., based on where the data is located. When using AOT
lowering or serialization, the computation may execute on a different
machine, or even on a platform that is not available at lowering time.
This means that it is not safe to write platform-dependent code using
Python conditionals, e.g., based on the current default JAX platform.
The proper way to do this is to introduce a primitive with
platform-specific lowering rules. This change introduces such a
primitive along with a user-facing API.
See more details in the docstring of lax.platform_dependent.
2023-11-02 14:31:38 +01:00
jax authors
1c66ac532b
Update XLA dependency to use revision
...
ca31652cdb
.
PiperOrigin-RevId: 578801983
2023-11-02 04:04:23 -07:00
Reed Wanderman-Milne
d41078fb95
Properly pack and unpack int4 arrays on CPU in PJRT.
...
Transferring an array from host to device on CPU sometimes does a zero-copy implementation where no memory is actually moved. This is now never done with int4, since int4 arrays are stored in packed format on device and an unpacked format on host. Similarly, transferring an array from device to host on CPU used to always use a zero-copy implementation, but now it will unpack and copy for int4 arrays.
PiperOrigin-RevId: 578692796
2023-11-01 17:39:24 -07:00
Adam Paszke
5d2896152f
[Mosaic] Add support for generalized reduction unrolling to infer_vector_layout
...
Otherwise it's inaccessible to users.
PiperOrigin-RevId: 578471167
2023-11-01 04:12:02 -07:00
Adam Paszke
f17f5492b5
Add support for dynamic indices in VMEM loads and stores
...
... at least in all but the last two dimensions, which have more stringent alignment requirements.
PiperOrigin-RevId: 578463563
2023-11-01 03:39:54 -07:00
jax authors
32a317f7a4
Update XLA dependency to use revision
...
7ab5df624f
.
PiperOrigin-RevId: 578457777
2023-11-01 03:11:12 -07:00
Tomás Longeri
caa1c39a39
[Mosaic][NFC] Remove unused parameter
...
PiperOrigin-RevId: 578434052
2023-11-01 01:22:22 -07:00
jax authors
a009f8d6c1
Pass flags from kernel into HLO backend config.
...
PiperOrigin-RevId: 578390868
2023-10-31 21:32:19 -07:00
jax authors
9f28512c4b
Merge pull request #18340 from jakevdp:keyarray-error
...
PiperOrigin-RevId: 578355093
2023-10-31 17:49:28 -07:00
Roy Frostig
16d082b002
[jex] replace extend.random.PRNGImpl
with extend.random.define_prng_impl
...
Instead of exposing a constructor, only expose a function that returns an opaque
object representing the defined implementation. This result can still be passed
to `jax.random.key` and `wrap_key_data`.
PiperOrigin-RevId: 578349699
2023-10-31 17:21:54 -07:00
Jake VanderPlas
a4e6b4e943
[random] add more information to KeyArray deprecation error
2023-10-31 16:54:49 -07:00
jax authors
49fedb1c52
Merge pull request #18338 from froystig:partitionable-threefry-ctx-mgr
...
PiperOrigin-RevId: 578325021
2023-10-31 15:44:35 -07:00
Roy Frostig
ed9a4c2939
add jax.threefry_partitionable
context manager
2023-10-31 13:45:55 -07:00
Roy Frostig
b22e75716f
add threefry_partitionable
config setting to thread-local JIT context
2023-10-31 13:45:49 -07:00
jax authors
29c6faa073
Merge pull request #18269 from jakevdp:new-tutorials
...
PiperOrigin-RevId: 578211023
2023-10-31 09:36:08 -07:00
jax authors
57e33dc3b5
Update XLA dependency to use revision
...
8f27d321a8
.
PiperOrigin-RevId: 578107451
2023-10-31 02:24:57 -07:00
jax authors
002e2e7993
Merge pull request #18158 from mattjj:gpu-performance-docs
...
PiperOrigin-RevId: 578041082
2023-10-30 20:57:23 -07:00
Matthew Johnson
664e834784
draft docs on gpu performance tuning
...
Co-authored-by: Tao Wang <wangtao@google.com>
2023-10-30 20:33:56 -07:00
Yash Katariya
20255dce84
Delete cached_call_jaxpr_lowerings since a more general cached_primitive_lowerings is available
...
PiperOrigin-RevId: 577993595
2023-10-30 16:38:57 -07:00
Yash Katariya
85af862efd
[Try again] For nested pjit's cache the generation of StableHLO if it satifies the key. This should help in improving the lowering time.
...
Reverts 4a5c6f82009dee9c30ca4a85359a702d745ed035
PiperOrigin-RevId: 577974380
2023-10-30 15:28:43 -07:00
Jake VanderPlas
344e42ba94
Documentation: add stub files for new tutorial structure
2023-10-30 13:58:29 -07:00