18100 Commits

Author SHA1 Message Date
Ben West
02f6fcb9da Add beta function 2023-11-05 15:37:38 -08:00
jax authors
1126945da8 Update XLA dependency to use revision
a58070090a.

PiperOrigin-RevId: 579592900
2023-11-05 01:58:12 -08:00
Tamás Danyluk
bfbf9e1c33 [XLA:GPU] Consider Triton for all non-pure GEMM fusions
This is a big step toward enabling xla_gpu_triton_gemm_any by default.

It shows about 1.05x geomean speedup on internal benchmarks (comparable to xla_gpu_triton_gemm_any=true).

PiperOrigin-RevId: 579524573
2023-11-04 16:05:19 -07:00
jax authors
dda76733e8 Update XLA dependency to use revision
75bd8cc701.

PiperOrigin-RevId: 579430578
2023-11-04 03:51:20 -07:00
jax authors
d73facd629 Merge pull request #18382 from jakevdp:prngkey-errors
PiperOrigin-RevId: 579365703
2023-11-03 19:46:27 -07:00
jax authors
28b512a457 Merge pull request #18137 from jakevdp:jep-numpy-scipy
PiperOrigin-RevId: 579365178
2023-11-03 19:37:12 -07:00
Jake VanderPlas
96d9f89415 [random] better errors for unsupported operations on prng keys 2023-11-03 19:23:18 -07:00
Jake VanderPlas
d623d04172 JEP 18137: Scope of JAX NumPy & SciPy Wrappers 2023-11-03 19:19:21 -07:00
Tomás Longeri
1c1dd7c8c7 [Mosaic] Expose C API for VectorLayout, VRegDataBounds
This is in preparation for Python bindings

PiperOrigin-RevId: 579355000
2023-11-03 18:24:16 -07:00
jax authors
953f4670d8 Merge pull request #18376 from jakevdp:doc-fix
PiperOrigin-RevId: 579279144
2023-11-03 13:02:09 -07:00
jax authors
e227536fd6 In api_test.py, wait for the result in test_double_donation.
PiperOrigin-RevId: 579267104
2023-11-03 12:23:55 -07:00
Jake VanderPlas
7d8c358fce Fix wording in TFDS example 2023-11-03 11:31:48 -07:00
jax authors
808289d52a Merge pull request #18373 from jakevdp:sharding-doc
PiperOrigin-RevId: 579224607
2023-11-03 10:23:44 -07:00
Jake VanderPlas
cd3ea05665 Ensure sharding-related array properties are documented 2023-11-03 09:56:33 -07:00
Peter Hawkins
011d49c518 Add a test for double donation.
The underlying issue was fixed some time ago.

Fixes https://github.com/google/jax/issues/9635

PiperOrigin-RevId: 579170638
2023-11-03 07:03:13 -07:00
jax authors
57385cb284 Update XLA dependency to use revision
049a3e6caf.

PiperOrigin-RevId: 579135741
2023-11-03 04:04:27 -07:00
jax authors
db07f40233 Fall-back to original device/backend hashing if topology-desc is unavailable.
The original cache-key generation algorithm hashed devices and backend as
part of generating the key. The new algorithm relies on serialized
PjRtTopologyDescription instead. Not all backends support serialized
PjRtTopologyDescription. Fall back to the original device/backend hashing
if the needed backend does not support it.

Testing: unit testing + test workloads.
PiperOrigin-RevId: 579039803
2023-11-02 18:43:48 -07:00
jax authors
59192b0dd1 Replace gcc with clang compiler in presubmit and postsubmit CI Kokoro jobs.
PiperOrigin-RevId: 579032820
2023-11-02 18:01:36 -07:00
jax authors
d1a8a7876b Merge pull request #18363 from skye:version
PiperOrigin-RevId: 579029941
2023-11-02 17:45:04 -07:00
Skye Wanderman-Milne
55e3072d2e Update versions and CHANGELOG after jax 0.4.20 release 2023-11-02 16:30:56 -07:00
jax authors
62741d9744 Reverts 81ac67f38164b7626d733d081a87ff49b235b9d0
PiperOrigin-RevId: 579010408
2023-11-02 16:17:29 -07:00
Jieying Luo
c9db50cfd0 Enable python_callback_test for stream executor.
python_callback_test is supported for GPU stream executor. TPU stream executor was deprecated.

PiperOrigin-RevId: 578960299
2023-11-02 13:26:59 -07:00
Parker Schuh
c8b7c1b80b Remove temporary flag for forcing arg tuplization of lowered functions.
PiperOrigin-RevId: 578910366
2023-11-02 10:53:16 -07:00
jax authors
1f6264896d Merge pull request #18354 from jakevdp:dep-opaque
PiperOrigin-RevId: 578904025
2023-11-02 10:37:47 -07:00
jax authors
6d5eaa6ec3 Merge pull request #18295 from gnecula:lax_multi
PiperOrigin-RevId: 578892592
2023-11-02 10:07:34 -07:00
Jake VanderPlas
0111dcbda3 Finish deprecation of allow_opaque_dtype 2023-11-02 09:51:06 -07:00
jax authors
e22b8bbc0a Merge pull request #18325 from skye:version
PiperOrigin-RevId: 578882609
jax-v0.4.20 jax-v0.4.20-rc
2023-11-02 09:37:54 -07:00
Skye Wanderman-Milne
6813819187 Update versions for jax + jaxlib 0.4.20 release 2023-11-02 09:34:16 -07:00
Etienne Pot
81ac67f381 Fix typing annotations for @jax.named_call
PiperOrigin-RevId: 578852649
2023-11-02 07:55:04 -07:00
George Necula
8feb413211 Add a lax.platform_dependent API for writing platform-dependent code.
In JAX the actual platform on which a computation is run is determined
very late, e.g., based on where the data is located. When using AOT
lowering or serialization, the computation may execute on a different
machine, or even on a platform that is not available at lowering time.
This means that it is not safe to write platform-dependent code using
Python conditionals, e.g., based on the current default JAX platform.
The proper way to do this is to introduce a primitive with
platform-specific lowering rules. This change introduces such a
primitive along with a user-facing API.

See more details in the docstring of lax.platform_dependent.
2023-11-02 14:31:38 +01:00
jax authors
1c66ac532b Update XLA dependency to use revision
ca31652cdb.

PiperOrigin-RevId: 578801983
2023-11-02 04:04:23 -07:00
Reed Wanderman-Milne
d41078fb95 Properly pack and unpack int4 arrays on CPU in PJRT.
Transferring an array from host to device on CPU sometimes does a zero-copy implementation where no memory is actually moved. This is now never done with int4, since int4 arrays are stored in packed format on device and an unpacked format on host. Similarly, transferring an array from device to host on CPU used to always use a zero-copy implementation, but now it will unpack and copy for int4 arrays.

PiperOrigin-RevId: 578692796
2023-11-01 17:39:24 -07:00
Adam Paszke
5d2896152f [Mosaic] Add support for generalized reduction unrolling to infer_vector_layout
Otherwise it's inaccessible to users.

PiperOrigin-RevId: 578471167
2023-11-01 04:12:02 -07:00
Adam Paszke
f17f5492b5 Add support for dynamic indices in VMEM loads and stores
... at least in all but the last two dimensions, which have more stringent alignment requirements.

PiperOrigin-RevId: 578463563
2023-11-01 03:39:54 -07:00
jax authors
32a317f7a4 Update XLA dependency to use revision
7ab5df624f.

PiperOrigin-RevId: 578457777
2023-11-01 03:11:12 -07:00
Tomás Longeri
caa1c39a39 [Mosaic][NFC] Remove unused parameter
PiperOrigin-RevId: 578434052
2023-11-01 01:22:22 -07:00
jax authors
a009f8d6c1 Pass flags from kernel into HLO backend config.
PiperOrigin-RevId: 578390868
2023-10-31 21:32:19 -07:00
jax authors
9f28512c4b Merge pull request #18340 from jakevdp:keyarray-error
PiperOrigin-RevId: 578355093
2023-10-31 17:49:28 -07:00
Roy Frostig
16d082b002 [jex] replace extend.random.PRNGImpl with extend.random.define_prng_impl
Instead of exposing a constructor, only expose a function that returns an opaque
object representing the defined implementation. This result can still be passed
to `jax.random.key` and `wrap_key_data`.

PiperOrigin-RevId: 578349699
2023-10-31 17:21:54 -07:00
Jake VanderPlas
a4e6b4e943 [random] add more information to KeyArray deprecation error 2023-10-31 16:54:49 -07:00
jax authors
49fedb1c52 Merge pull request #18338 from froystig:partitionable-threefry-ctx-mgr
PiperOrigin-RevId: 578325021
2023-10-31 15:44:35 -07:00
Roy Frostig
ed9a4c2939 add jax.threefry_partitionable context manager 2023-10-31 13:45:55 -07:00
Roy Frostig
b22e75716f add threefry_partitionable config setting to thread-local JIT context 2023-10-31 13:45:49 -07:00
jax authors
29c6faa073 Merge pull request #18269 from jakevdp:new-tutorials
PiperOrigin-RevId: 578211023
2023-10-31 09:36:08 -07:00
jax authors
57e33dc3b5 Update XLA dependency to use revision
8f27d321a8.

PiperOrigin-RevId: 578107451
2023-10-31 02:24:57 -07:00
jax authors
002e2e7993 Merge pull request #18158 from mattjj:gpu-performance-docs
PiperOrigin-RevId: 578041082
2023-10-30 20:57:23 -07:00
Matthew Johnson
664e834784 draft docs on gpu performance tuning
Co-authored-by: Tao Wang <wangtao@google.com>
2023-10-30 20:33:56 -07:00
Yash Katariya
20255dce84 Delete cached_call_jaxpr_lowerings since a more general cached_primitive_lowerings is available
PiperOrigin-RevId: 577993595
2023-10-30 16:38:57 -07:00
Yash Katariya
85af862efd [Try again] For nested pjit's cache the generation of StableHLO if it satifies the key. This should help in improving the lowering time.
Reverts 4a5c6f82009dee9c30ca4a85359a702d745ed035

PiperOrigin-RevId: 577974380
2023-10-30 15:28:43 -07:00
Jake VanderPlas
344e42ba94 Documentation: add stub files for new tutorial structure 2023-10-30 13:58:29 -07:00