Sharad Vikram
a7a6b40b55
[Pallas] Add interpret mode support for dynamic grid
...
PiperOrigin-RevId: 603818776
2024-02-02 16:37:47 -08:00
Jieying Luo
c4b6266049
Only register profiler when the plugin is dynamically loaded.
...
Both jaxlib and plugin calls tsl::profiler::CreateProfilers, which needs to access a singleton with a lock. In the static linking case, this singleton is shared by jaxlib and plugin, and will cause deadlock.
PiperOrigin-RevId: 603777779
2024-02-02 13:48:26 -08:00
Sergei Lebedev
cda40ece87
Removed wrap_with_builder and tensor.to from the Triton compatibility layer
...
The only API using it was cast.
I also added a test which covers int1->int8 casts.
PiperOrigin-RevId: 603771797
2024-02-02 13:23:52 -08:00
Sergei Lebedev
28eff4f9b8
Migrated dot to lower directly to Triton IR
...
PiperOrigin-RevId: 603768074
2024-02-02 13:09:25 -08:00
Sergei Lebedev
5867a05cdd
Migrated store/load to lower directly to Triton IR
...
PiperOrigin-RevId: 603764118
2024-02-02 12:53:42 -08:00
Anlun Xu
16636f9c97
[jax_triton] Only use side stream to do autotuning when doing graph capture
...
When graph capture is not enabled, autotuning and kernel launch should be on the same stream to avoid race condition.
PiperOrigin-RevId: 603728867
2024-02-02 10:48:26 -08:00
Sergei Lebedev
e1ea936fc1
Added a custom lowering rule for pow which special-cases weak dtypes
...
PiperOrigin-RevId: 603635095
2024-02-02 03:06:43 -08:00
Sergei Lebedev
d9f42c56b8
Fixed a few ir.Value type-tensor dtype mismatches in the Pallas lowering code on GPU
...
PiperOrigin-RevId: 603632044
2024-02-02 02:52:08 -08:00
George Necula
fdf227e7b2
[export] Set default native serialization version to 9.
...
This version adds better support for JAX effects.
See description in CHANGELOG.md and also at
https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions .
PiperOrigin-RevId: 603579274
2024-02-01 21:56:03 -08:00
jax authors
f7cbd1c479
Update XLA dependency to use revision
...
04dc5d6cd2
.
PiperOrigin-RevId: 603567192
2024-02-01 20:37:31 -08:00
Eugene Zhulenev
28ef77dfb0
Disable JAX compilation cache for XLA:CPU
...
PiperOrigin-RevId: 603551262
2024-02-01 19:20:39 -08:00
Sharad Vikram
a41385c860
[Pallas/TPU] Allow 1-sized batch dim in vmap of dynamic grid
...
PiperOrigin-RevId: 603518847
2024-02-01 16:43:18 -08:00
Sharad Vikram
d76705da94
[Pallas/TPU] Add vmap support for dynamic grid
...
PiperOrigin-RevId: 603502393
2024-02-01 15:39:32 -08:00
jax authors
b6cda05218
Merge pull request #19622 from jakevdp:rng-uniform-doc
...
PiperOrigin-RevId: 603420400
2024-02-01 10:58:20 -08:00
jax authors
bf9e3e3054
Merge pull request #19621 from jakevdp:cholesky
...
PiperOrigin-RevId: 603416209
2024-02-01 10:44:44 -08:00
Jake VanderPlas
5041e402e1
DOC: clarify lax.rng_uniform behavior for b<=a
2024-02-01 10:20:04 -08:00
jax authors
bb5c90bdfa
Merge pull request #19592 from mattjj:jax-attrs3
...
PiperOrigin-RevId: 603401327
2024-02-01 09:58:12 -08:00
Jake VanderPlas
bce92e5f98
BUG: fix cholesky upper implementation
2024-02-01 09:18:28 -08:00
Adam Paszke
21070b24d7
Add support for dynamically computed grid bounds in Pallas kernels.
...
PiperOrigin-RevId: 603389883
2024-02-01 09:15:19 -08:00
jax authors
69e4dd41b5
Update XLA dependency to use revision
...
7e5de8712f
.
PiperOrigin-RevId: 603255973
2024-01-31 21:37:56 -08:00
Matthew Johnson
039e8dfd9f
basic scan support
...
Co-authord-by: Dougal Maclaurin <dogualm@google.com>
2024-01-31 20:50:41 -08:00
jax authors
26995ad800
Merge pull request #19580 from jakevdp:callback-error
...
PiperOrigin-RevId: 603228092
2024-01-31 19:09:09 -08:00
Jieying Luo
e1cf807513
[PJRT C API] Add an argument PJRT_Api* to register_plugin so that a plugin can be registered either with a library path, or a PJRT_Api*.
...
Statically linked plugin can only provide PJRT_Api*. This makes dynamic and static linking plugin paths more consistent, in particular for any optional feature that depends on PJRT_Api* (e.g. profiler and AOT topology).
PiperOrigin-RevId: 603210143
2024-01-31 17:42:00 -08:00
Jake VanderPlas
0af74aab98
jax.make_array_from_callback: better errors in traced context
2024-01-31 15:13:33 -08:00
jax authors
44a7d022f8
Merge pull request #19606 from jakevdp:cholesky-upper
...
PiperOrigin-RevId: 603172800
2024-01-31 15:13:02 -08:00
Jake VanderPlas
2878567d43
api_test: install ci jaxlib version
2024-01-31 14:31:12 -08:00
Jake VanderPlas
c9a700921b
jnp.linalg.cholesky: add upper argument
2024-01-31 14:16:12 -08:00
Jieying Luo
29f1d3b033
[PJRT C API] Use xla_client.generate_pjrt_gpu_plugin_options to generate options for CUDA plugin.
...
PiperOrigin-RevId: 603074180
2024-01-31 09:42:58 -08:00
Jieying Luo
cad665401e
Change indexing_test from py_test to jax_test so that it will be included when using bazel.
...
PiperOrigin-RevId: 603074045
2024-01-31 09:34:07 -08:00
jax authors
b405ce7f37
Update XLA dependency to use revision
...
492c50766d
.
PiperOrigin-RevId: 602937257
2024-01-30 22:06:20 -08:00
Ralf W. Grosse-Kunstleve
4f5e71ca5a
Add missing super().__init__()
involving types wrapped in xla/python/sharding.cc
...
This change is to unblock https://github.com/google/pywrapcc/pull/30095 .
Leaving wrapped C++ types uninitialized creates a potential for triggering undefined behavior from Python.
PiperOrigin-RevId: 602884828
2024-01-30 17:17:12 -08:00
jax authors
4393f84680
Merge pull request #19593 from trishume:patch-1
...
PiperOrigin-RevId: 602881185
2024-01-30 17:01:03 -08:00
Tristan Hume
7933acdb90
Add type annotation to pl.load
2024-01-30 16:32:29 -08:00
jax authors
af2292aa4e
Merge pull request #19591 from jakevdp:key-reuse-slice
...
PiperOrigin-RevId: 602868125
2024-01-30 16:07:43 -08:00
jax authors
80d23d64cd
Merge pull request #19566 from mattjj:attrs-aqt
...
PiperOrigin-RevId: 602864008
2024-01-30 15:51:00 -08:00
Jake VanderPlas
a4296add2d
[key reuse] simplify slice_p reuse rule
2024-01-30 15:45:40 -08:00
jax authors
cce6520dfa
Merge pull request #19587 from jakevdp:key-reuse-info
...
PiperOrigin-RevId: 602838627
2024-01-30 14:17:17 -08:00
Jake VanderPlas
1945e8dc2b
[key reuse] better info for scan & while failures
2024-01-30 13:55:47 -08:00
jax authors
5951c5f93c
Merge pull request #19583 from jakevdp:key-reuse-error
...
PiperOrigin-RevId: 602817945
2024-01-30 13:07:09 -08:00
Jake VanderPlas
a56e8e87e5
[key reuse] print signature on failure
2024-01-30 12:46:15 -08:00
jax authors
54ba49d333
Merge pull request #19574 from sboukortt:jax-101
...
PiperOrigin-RevId: 602787303
2024-01-30 11:21:35 -08:00
Sami Boukortt
222d6f29a1
Clarify a comment slightly
...
“When” makes it sound as though the error is raised at the point where
the cast is made. “If” makes it clearer that while the cast is the root
cause, the error occurs later on.
2024-01-30 17:08:09 +01:00
Tomás Longeri
ca98ed7c40
[Mosaic] In apply_vector_layout, remove old layout attribute formats
...
PiperOrigin-RevId: 602723113
2024-01-30 07:44:10 -08:00
Sergei Lebedev
9e76e380cc
Temporarily switch triton.compat to use Triton APIs for math and semantic operations
...
This is only meant as a short-term fix to unblock internal users.
PiperOrigin-RevId: 602707085
2024-01-30 06:30:22 -08:00
Goran Flegar
66308c30ad
Integrate Triton up to [9f816a7b]( 9f816a7b98
)
...
PiperOrigin-RevId: 602641874
2024-01-30 01:16:11 -08:00
Enrique Piqueras
da6fa63bf3
Add missing lowering rules.
...
PiperOrigin-RevId: 602598917
2024-01-29 21:42:38 -08:00
jax authors
d9293f8a68
Update XLA dependency to use revision
...
3582d3ef94
.
PiperOrigin-RevId: 602596006
2024-01-29 21:24:09 -08:00
Matthew Johnson
6c2d9c7e3a
add getstate/setstate in pjit transpose, for bwd pass effects
...
Co-authored-by: Roy Frostig <frostig@google.com>
2024-01-29 20:03:11 -08:00
Yash Katariya
d9122b8bac
Add sharding to ShapeDtypeStruct retured by eval_shape if jit
has out_shardings
specified
...
PiperOrigin-RevId: 602556016
2024-01-29 18:02:51 -08:00
jax authors
52b16867a5
Merge pull request #19559 from jakevdp:key-reuse-shape-poly
...
PiperOrigin-RevId: 602503831
2024-01-29 14:38:54 -08:00