19208 Commits

Author SHA1 Message Date
Sharad Vikram
a7a6b40b55 [Pallas] Add interpret mode support for dynamic grid
PiperOrigin-RevId: 603818776
2024-02-02 16:37:47 -08:00
Jieying Luo
c4b6266049 Only register profiler when the plugin is dynamically loaded.
Both jaxlib and plugin calls tsl::profiler::CreateProfilers, which needs to access a singleton with a lock. In the static linking case, this singleton is shared by jaxlib and plugin, and will cause deadlock.

PiperOrigin-RevId: 603777779
2024-02-02 13:48:26 -08:00
Sergei Lebedev
cda40ece87 Removed wrap_with_builder and tensor.to from the Triton compatibility layer
The only API using it was cast.

I also added a test which covers int1->int8 casts.

PiperOrigin-RevId: 603771797
2024-02-02 13:23:52 -08:00
Sergei Lebedev
28eff4f9b8 Migrated dot to lower directly to Triton IR
PiperOrigin-RevId: 603768074
2024-02-02 13:09:25 -08:00
Sergei Lebedev
5867a05cdd Migrated store/load to lower directly to Triton IR
PiperOrigin-RevId: 603764118
2024-02-02 12:53:42 -08:00
Anlun Xu
16636f9c97 [jax_triton] Only use side stream to do autotuning when doing graph capture
When graph capture is not enabled, autotuning and kernel launch should be on the same stream to avoid race condition.

PiperOrigin-RevId: 603728867
2024-02-02 10:48:26 -08:00
Sergei Lebedev
e1ea936fc1 Added a custom lowering rule for pow which special-cases weak dtypes
PiperOrigin-RevId: 603635095
2024-02-02 03:06:43 -08:00
Sergei Lebedev
d9f42c56b8 Fixed a few ir.Value type-tensor dtype mismatches in the Pallas lowering code on GPU
PiperOrigin-RevId: 603632044
2024-02-02 02:52:08 -08:00
George Necula
fdf227e7b2 [export] Set default native serialization version to 9.
This version adds better support for JAX effects.

See description in CHANGELOG.md and also at
https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions.

PiperOrigin-RevId: 603579274
2024-02-01 21:56:03 -08:00
jax authors
f7cbd1c479 Update XLA dependency to use revision
04dc5d6cd2.

PiperOrigin-RevId: 603567192
2024-02-01 20:37:31 -08:00
Eugene Zhulenev
28ef77dfb0 Disable JAX compilation cache for XLA:CPU
PiperOrigin-RevId: 603551262
2024-02-01 19:20:39 -08:00
Sharad Vikram
a41385c860 [Pallas/TPU] Allow 1-sized batch dim in vmap of dynamic grid
PiperOrigin-RevId: 603518847
2024-02-01 16:43:18 -08:00
Sharad Vikram
d76705da94 [Pallas/TPU] Add vmap support for dynamic grid
PiperOrigin-RevId: 603502393
2024-02-01 15:39:32 -08:00
jax authors
b6cda05218 Merge pull request #19622 from jakevdp:rng-uniform-doc
PiperOrigin-RevId: 603420400
2024-02-01 10:58:20 -08:00
jax authors
bf9e3e3054 Merge pull request #19621 from jakevdp:cholesky
PiperOrigin-RevId: 603416209
2024-02-01 10:44:44 -08:00
Jake VanderPlas
5041e402e1 DOC: clarify lax.rng_uniform behavior for b<=a 2024-02-01 10:20:04 -08:00
jax authors
bb5c90bdfa Merge pull request #19592 from mattjj:jax-attrs3
PiperOrigin-RevId: 603401327
2024-02-01 09:58:12 -08:00
Jake VanderPlas
bce92e5f98 BUG: fix cholesky upper implementation 2024-02-01 09:18:28 -08:00
Adam Paszke
21070b24d7 Add support for dynamically computed grid bounds in Pallas kernels.
PiperOrigin-RevId: 603389883
2024-02-01 09:15:19 -08:00
jax authors
69e4dd41b5 Update XLA dependency to use revision
7e5de8712f.

PiperOrigin-RevId: 603255973
2024-01-31 21:37:56 -08:00
Matthew Johnson
039e8dfd9f basic scan support
Co-authord-by: Dougal Maclaurin <dogualm@google.com>
2024-01-31 20:50:41 -08:00
jax authors
26995ad800 Merge pull request #19580 from jakevdp:callback-error
PiperOrigin-RevId: 603228092
2024-01-31 19:09:09 -08:00
Jieying Luo
e1cf807513 [PJRT C API] Add an argument PJRT_Api* to register_plugin so that a plugin can be registered either with a library path, or a PJRT_Api*.
Statically linked plugin can only provide PJRT_Api*. This makes dynamic and static linking plugin paths more consistent, in particular for any optional feature that depends on PJRT_Api* (e.g. profiler and AOT topology).

PiperOrigin-RevId: 603210143
2024-01-31 17:42:00 -08:00
Jake VanderPlas
0af74aab98 jax.make_array_from_callback: better errors in traced context 2024-01-31 15:13:33 -08:00
jax authors
44a7d022f8 Merge pull request #19606 from jakevdp:cholesky-upper
PiperOrigin-RevId: 603172800
2024-01-31 15:13:02 -08:00
Jake VanderPlas
2878567d43 api_test: install ci jaxlib version 2024-01-31 14:31:12 -08:00
Jake VanderPlas
c9a700921b jnp.linalg.cholesky: add upper argument 2024-01-31 14:16:12 -08:00
Jieying Luo
29f1d3b033 [PJRT C API] Use xla_client.generate_pjrt_gpu_plugin_options to generate options for CUDA plugin.
PiperOrigin-RevId: 603074180
2024-01-31 09:42:58 -08:00
Jieying Luo
cad665401e Change indexing_test from py_test to jax_test so that it will be included when using bazel.
PiperOrigin-RevId: 603074045
2024-01-31 09:34:07 -08:00
jax authors
b405ce7f37 Update XLA dependency to use revision
492c50766d.

PiperOrigin-RevId: 602937257
2024-01-30 22:06:20 -08:00
Ralf W. Grosse-Kunstleve
4f5e71ca5a Add missing super().__init__() involving types wrapped in xla/python/sharding.cc
This change is to unblock https://github.com/google/pywrapcc/pull/30095.

Leaving wrapped C++ types uninitialized creates a potential for triggering undefined behavior from Python.

PiperOrigin-RevId: 602884828
2024-01-30 17:17:12 -08:00
jax authors
4393f84680 Merge pull request #19593 from trishume:patch-1
PiperOrigin-RevId: 602881185
2024-01-30 17:01:03 -08:00
Tristan Hume
7933acdb90
Add type annotation to pl.load 2024-01-30 16:32:29 -08:00
jax authors
af2292aa4e Merge pull request #19591 from jakevdp:key-reuse-slice
PiperOrigin-RevId: 602868125
2024-01-30 16:07:43 -08:00
jax authors
80d23d64cd Merge pull request #19566 from mattjj:attrs-aqt
PiperOrigin-RevId: 602864008
2024-01-30 15:51:00 -08:00
Jake VanderPlas
a4296add2d [key reuse] simplify slice_p reuse rule 2024-01-30 15:45:40 -08:00
jax authors
cce6520dfa Merge pull request #19587 from jakevdp:key-reuse-info
PiperOrigin-RevId: 602838627
2024-01-30 14:17:17 -08:00
Jake VanderPlas
1945e8dc2b [key reuse] better info for scan & while failures 2024-01-30 13:55:47 -08:00
jax authors
5951c5f93c Merge pull request #19583 from jakevdp:key-reuse-error
PiperOrigin-RevId: 602817945
2024-01-30 13:07:09 -08:00
Jake VanderPlas
a56e8e87e5 [key reuse] print signature on failure 2024-01-30 12:46:15 -08:00
jax authors
54ba49d333 Merge pull request #19574 from sboukortt:jax-101
PiperOrigin-RevId: 602787303
2024-01-30 11:21:35 -08:00
Sami Boukortt
222d6f29a1 Clarify a comment slightly
“When” makes it sound as though the error is raised at the point where
the cast is made. “If” makes it clearer that while the cast is the root
cause, the error occurs later on.
2024-01-30 17:08:09 +01:00
Tomás Longeri
ca98ed7c40 [Mosaic] In apply_vector_layout, remove old layout attribute formats
PiperOrigin-RevId: 602723113
2024-01-30 07:44:10 -08:00
Sergei Lebedev
9e76e380cc Temporarily switch triton.compat to use Triton APIs for math and semantic operations
This is only meant as a short-term fix to unblock internal users.

PiperOrigin-RevId: 602707085
2024-01-30 06:30:22 -08:00
Goran Flegar
66308c30ad Integrate Triton up to [9f816a7b](9f816a7b98)
PiperOrigin-RevId: 602641874
2024-01-30 01:16:11 -08:00
Enrique Piqueras
da6fa63bf3 Add missing lowering rules.
PiperOrigin-RevId: 602598917
2024-01-29 21:42:38 -08:00
jax authors
d9293f8a68 Update XLA dependency to use revision
3582d3ef94.

PiperOrigin-RevId: 602596006
2024-01-29 21:24:09 -08:00
Matthew Johnson
6c2d9c7e3a add getstate/setstate in pjit transpose, for bwd pass effects
Co-authored-by: Roy Frostig <frostig@google.com>
2024-01-29 20:03:11 -08:00
Yash Katariya
d9122b8bac Add sharding to ShapeDtypeStruct retured by eval_shape if jit has out_shardings specified
PiperOrigin-RevId: 602556016
2024-01-29 18:02:51 -08:00
jax authors
52b16867a5 Merge pull request #19559 from jakevdp:key-reuse-shape-poly
PiperOrigin-RevId: 602503831
2024-01-29 14:38:54 -08:00