24110 Commits

Author SHA1 Message Date
Skye Wanderman-Milne
94f9a488b1 Don't override --xla_tpu_use_enhanced_launch_barrier if explicitly set 2024-11-13 22:11:39 +00:00
jax authors
4a884d4184 Update XLA dependency to use revision
2a7890387f.

PiperOrigin-RevId: 696255293
2024-11-13 13:40:49 -08:00
jax authors
14e08aa271 Merge pull request #24874 from pearu:pearu/square_p
PiperOrigin-RevId: 696251565
2024-11-13 13:29:55 -08:00
jax authors
0a755aece5 Merge pull request #22899 from trevor-m:cache
PiperOrigin-RevId: 696202017
2024-11-13 11:10:02 -08:00
jax authors
12c8c68c4a Merge pull request #24069 from sergachev:cudnn_fusion_test_a100
PiperOrigin-RevId: 696200281
2024-11-13 11:06:08 -08:00
Trevor Morris
a79d307ac7 When caching is enabled, also enable XLA caching features
Add unit test

Fix typechecker

Set caching mode depending on process id
2024-11-13 10:30:04 -08:00
Peter Buchlovsky
558ebb9fb1 Add Pallas Triton lowering for jax.lax.bitcast_convert_type.
Only handles the case where operand type and target type have the same bitwidth.

PiperOrigin-RevId: 696184251
2024-11-13 10:25:53 -08:00
Pearu Peterson
4d0a007d57 Add square_p 2024-11-13 20:14:37 +02:00
jax authors
d83517bae5 Merge pull request #24877 from dfm:cond-rep-rule
PiperOrigin-RevId: 696176870
2024-11-13 10:07:03 -08:00
jax authors
93a1f9d317 [AutoPGLE] Fix test after pjrt cache refactoring
PiperOrigin-RevId: 696156229
2024-11-13 09:03:11 -08:00
Peter Hawkins
be3c8be186 Fix bug where the Python wrapper to ParseArguments() didn't intern the static argnames strings, causing false mismatches when searching for static arguments.
Fixes https://github.com/jax-ml/jax/issues/24857

PiperOrigin-RevId: 696151287
2024-11-13 08:47:57 -08:00
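A minimal sketch of why interning matters here, written in plain Python rather than in the wrapper code the commit touches. It assumes the lookup compares names by identity (a pointer check), which is an assumption about the mechanism; `find_static_by_identity` is a hypothetical helper, not a JAX function.

```python
import sys

def find_static_by_identity(name, static_argnames):
    """Hypothetical lookup that compares by identity, as a pointer check would."""
    return any(name is candidate for candidate in static_argnames)

# Build the string at runtime so it is a separate object from the stored name.
runtime_name = "".join(["static", "_", "arg"])
static_argnames = (sys.intern("static_arg"),)

# The two strings are equal by value.
by_value = runtime_name in static_argnames

# Interning maps equal strings to one shared object, so identity lookup succeeds.
by_identity_interned = find_static_by_identity(
    sys.intern(runtime_name), static_argnames
)
```

Without the `sys.intern` call on the lookup side, an identity comparison may miss a name that is equal by value, which is the kind of false mismatch the commit message describes.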
Peter Hawkins
bc82203a5c Avoid using a contextmanager in Primitive.bind.
It's slightly faster to inline the context manager code into the implementation of bind.

PiperOrigin-RevId: 696142743
2024-11-13 08:20:36 -08:00
Dan Foreman-Mackey
dfabcb027d Add a shard map replication rule for cond_p. 2024-11-13 06:33:57 -08:00
jax authors
9a28b561a6 Fix parallel pgle-tests execution.
PiperOrigin-RevId: 696031645
2024-11-13 01:32:18 -08:00
Vlad Sytchenko
f2a25cc231 [XLA] Make our LLVM usage more googley
With the advent of heterogeneous compute, XLA compilation now encompasses sub-compilation for multiple devices. These can all use LLVM, but with different settings. Today this means it is possible for one XLA client to reinitialize LLVM's global state while another client is in the middle of compilation.

Add a global lock around our LLVM usage. Concurrent compilation is still allowed, as long as both invocations have the same set of options. This means from within the same client multiple compilation invocations should still be non-blocking.

PiperOrigin-RevId: 695981613
2024-11-12 22:00:38 -08:00
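One way to picture a lock that still permits concurrency between callers with identical options is the sketch below. It is a Python illustration under stated assumptions, not XLA's C++ implementation; the class `OptionsGate` and its `hold` method are hypothetical names.

```python
import threading
from contextlib import contextmanager

class OptionsGate:
    """Hypothetical gate: holders with equal options may run concurrently."""

    def __init__(self):
        self._cond = threading.Condition()
        self._options = None   # options of the current holders
        self._holders = 0      # number of active holders

    @contextmanager
    def hold(self, options):
        with self._cond:
            # Wait only while someone holds the gate with different options.
            while self._holders and self._options != options:
                self._cond.wait()
            self._options = options
            self._holders += 1
        try:
            yield
        finally:
            with self._cond:
                self._holders -= 1
                if self._holders == 0:
                    # Last holder out: wake waiters with other options.
                    self._cond.notify_all()
```

A caller would wrap each compilation in `with gate.hold(options): ...`; callers with the same options enter immediately, while a caller with different options waits until the gate is empty.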
Nitin Srinivasan
195d407081 Add new CI scripts for running Bazel CPU presubmits
This commit introduces new CI scripts and environment files for running Bazel CPU presubmits.

* Adds a ci directory at the root of the repository to store these files.
* Environment files are located in ci/envs and define new JAXCI_ environment variables to control CI build behavior.
* The build script sources these environment files and sets up the build environment before running the build commands.

PiperOrigin-RevId: 695957540
2024-11-12 20:00:35 -08:00
jax authors
ed9fdbbf0a Merge pull request #24842 from jakevdp:batched-toeplitz
PiperOrigin-RevId: 695917476
2024-11-12 16:52:25 -08:00
Dougal Maclaurin
d47e254100 Dedent your yields!
Fixes a surprising interaction between the generator system in linear_util.py
and the try/finally python context managers we use for managing tracing context.
The `finally` block wasn't always being called until garbage collection, so the
context stack pushes/pops weren't always correctly nested. Dedenting the yield
fixes this particular bug but long-term we should get rid of linear_util
altogether.

PiperOrigin-RevId: 695898528
2024-11-12 15:51:33 -08:00
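The interaction described above can be reproduced with a plain generator. This is a minimal sketch, not code from linear_util.py; the `events` list and both generator names are hypothetical.

```python
events = []

def indented():
    try:
        events.append("push")
        yield 1              # suspended inside try: cleanup is deferred
    finally:
        events.append("pop")

def dedented():
    try:
        events.append("push")
        result = 1
    finally:
        events.append("pop")
    yield result             # cleanup already ran before suspension

gen = indented()
next(gen)
after_indented = list(events)    # ["push"]: the finally block has not run yet
gen.close()                      # finally runs only on close or garbage collection

events.clear()
gen2 = dedented()
next(gen2)
after_dedented = list(events)    # ["push", "pop"]
```

When the yield sits inside the `try`, the pop happens whenever the generator is eventually closed or collected, so pushes and pops from different generators can interleave out of order. Dedenting the yield makes the pop happen before control leaves the generator.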
Naums Mogers
c32db46e6c [Mosaic] Add parameter names to tpu.sem_signal and add tests
This CL adds parameter names to the optional parameters of `tpu.sem_signal` -- `device_id`, `core_id` -- to remove the ambiguity upon deserialization.
Adds LIT tests of signalling on TC with parameter names.

PiperOrigin-RevId: 695875037
2024-11-12 14:37:47 -08:00
jax authors
370c4a70bb Change the assumed width of the bool packing in the early-lowering checks in pallas
PiperOrigin-RevId: 695856621
2024-11-12 13:44:49 -08:00
jax authors
c4a0369f5c Update XLA dependency to use revision
f17344020c.

PiperOrigin-RevId: 695838490
2024-11-12 12:52:22 -08:00
jax authors
8e224dbc71 Merge pull request #24792 from dfm:ffi-tutorial-outdated
PiperOrigin-RevId: 695801160
2024-11-12 11:07:12 -08:00
Sergei Lebedev
d304025a41 [mosaic_gpu] The profiler now uses FFI calls for creating events and computing elapsed time
PiperOrigin-RevId: 695798787
2024-11-12 11:01:59 -08:00
jax authors
1221da8467 [Mosaic] Fix mask creation for packed sublanes
Unaligned concat used to be f32 only, but implicitly protected via unimplemented support for multi-row-shift in sub32 types. When this was added, we started invoking unaligned concat flow w/ sub32 types, but the masking code that assumed full rows (unpacked types) was no longer sufficient - we need better granularity for these cases. This only affects sublanes, as that is where we pack, we don't have partial lanes.

This CL, as a small benefit, also adds better error messages to the ops involved in lower_to_llo.cc.

PiperOrigin-RevId: 695796095
2024-11-12 10:55:19 -08:00
James Martens
310ff7347c Change to internal dead code elimination. Now the functions in dce_rules are responsible for checking if the equation has no used outputs or effects, and behaving appropriately in that case (which usually means eliminating said equation).
PiperOrigin-RevId: 695789033
2024-11-12 10:37:04 -08:00
jax authors
3a5ac487a6 Merge pull request #24806 from jakevdp:ufunc-decorator
PiperOrigin-RevId: 695776501
2024-11-12 10:04:36 -08:00
Dougal Maclaurin
64fcb9d3e9 Fix pgle profiling, broken in previous change.
PiperOrigin-RevId: 695762690
2024-11-12 09:25:27 -08:00
jax authors
b185e64a85 Merge pull request #24860 from dfm:overflow-changelog
PiperOrigin-RevId: 695762639
2024-11-12 09:23:43 -08:00
Dan Foreman-Mackey
5808170a10 Add GPU overflow bugfix (#24846) to changelog. 2024-11-12 08:57:52 -08:00
Dan Foreman-Mackey
f757054267 Update some outdated syntax in FFI tutorial. 2024-11-12 08:34:24 -08:00
Dan Foreman-Mackey
a99ccd9341 Remove GPU test with unreasonably large memory footprint.
PiperOrigin-RevId: 695717589
2024-11-12 07:02:57 -08:00
Dan Foreman-Mackey
21e98b5ce4 Fix overflow error in GPU batched linear algebra kernels.
As reported in https://github.com/jax-ml/jax/issues/24843, our LU decomposition on GPU hits overflow errors when the batch size approaches int32 max. This was caused by an issue in how we were constructing the batched pointers used by cuBLAS.

PiperOrigin-RevId: 695694648
2024-11-12 05:33:49 -08:00
Dan Foreman-Mackey
9bb6366741 Allow more output storage types for some dot algorithms.
As reported in https://github.com/jax-ml/jax/issues/24794, there were some dot products that were resulting in an unnecessary conversion. This change makes the output storage type selection more flexible.

Fixes https://github.com/jax-ml/jax/issues/24794

PiperOrigin-RevId: 695694179
2024-11-12 05:31:50 -08:00
jax authors
837bcccefa Merge pull request #24772 from dfm:ffi-call-no-canonicalize
PiperOrigin-RevId: 695692023
2024-11-12 05:20:42 -08:00
jax authors
e79eca6f63 Merge pull request #24854 from hurryabit:is_cache_used-return-bool
PiperOrigin-RevId: 695688910
2024-11-12 05:07:54 -08:00
Junwhan Ahn
2582a337a6 Explicitly raise an error if more than 65535 channels are created
`xla::HostCallbackArgInfo` uses `uint16_t` for channel ids, so we should warn users explicitly when the channel ids exceed the UINT16_MAX instead of silently wrapping around.

PiperOrigin-RevId: 695682871
2024-11-12 04:42:25 -08:00
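The difference between silently wrapping and raising explicitly can be sketched in Python. The function `checked_channel_id` is a hypothetical illustration, not the code added by the commit.

```python
UINT16_MAX = 0xFFFF  # 65535, the largest value a uint16_t can hold

def checked_channel_id(channel_id: int) -> int:
    """Hypothetical check: reject ids that would not fit in 16 bits."""
    if not 0 <= channel_id <= UINT16_MAX:
        raise ValueError(
            f"channel id {channel_id} does not fit in uint16 (max {UINT16_MAX})"
        )
    return channel_id

# What a narrowing conversion to 16 bits would do without the check:
wrapped = 65536 & UINT16_MAX   # wraps around to 0, colliding with channel 0
```

Raising at the point of creation turns a hard-to-trace collision between channel ids into an immediate, descriptive error.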
Sergei Lebedev
15f30a9e9c [pallas:mosaic_gpu] emit_pipeline now maintains the grid indices
Previously, it was recomputing them at every loop iteration.

PiperOrigin-RevId: 695682116
2024-11-12 04:39:17 -08:00
Chris Jones
cb82609ae5 [pallas:triton] Fix reshape lowering with scalar output shape.
PiperOrigin-RevId: 695678909
2024-11-12 04:26:15 -08:00
jax authors
5ec08767a6 Merge pull request #24823 from gnecula:poly_jvp_sort
PiperOrigin-RevId: 695678140
2024-11-12 04:24:03 -08:00
jax authors
8420e22259 Merge pull request #24822 from gnecula:delete_outfeed_rewriter
PiperOrigin-RevId: 695678083
2024-11-12 04:22:03 -08:00
jax authors
4363bb65d7 Merge pull request #24770 from jakevdp:extended-device-get
PiperOrigin-RevId: 695671688
2024-11-12 03:58:23 -08:00
George Necula
fb68c97a0d [shape_poly] Fix the handling of jvp(lax.sort)
Previously, `jvp(lax.sort)` used a shape-dependent dtype, for
the types of indices (either `int32` or `int64`, depending on
the size of the dimension). For shape polymorphism, input shapes
can affect other intermediate shapes, but not `dtype`s.

In this case it is easy to just use `int64` independent of
the actual shape.
2024-11-12 03:36:05 -08:00
George Necula
c92507772c Cleanup more remnants of the jax.experimental.host_callback
Removes the outfeed rewriter mechanism and helper functions
`jaxpr_uses_outfeed`, which were used only by
`jax.experimental.host_callback`.
2024-11-12 03:27:10 -08:00
Martin Huschenbett
31e42d8e91 Make sure compilation_cache.is_cache_used always returns a bool
In some cases, `compilation_cache.is_cache_used` can reach the end of the function body without returning anything. This amounts to an implicit `return None`, which is not in line with the function's return type of `bool`. We fix this by adding a final `return False` to the function.
2024-11-12 11:44:31 +01:00
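The implicit `return None` pattern is easy to reproduce. The two functions below are hypothetical and do not mirror the real body of `is_cache_used`; they only show the shape of the bug and of the fix.

```python
def is_used_implicit(enabled: bool, initialized: bool) -> bool:
    if enabled:
        if initialized:
            return True
    # Falls off the end here: Python returns None despite the bool annotation.

def is_used_fixed(enabled: bool, initialized: bool) -> bool:
    if enabled:
        if initialized:
            return True
    return False  # explicit final return keeps the result a bool
```

`None` is falsy, so callers that only test truthiness behave the same, but a caller that compares with `is False` or relies on the annotated type would see the difference.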
Sharad Vikram
54e72d5054 Add wraparound for 2x2x2 v5p
PiperOrigin-RevId: 695603337
2024-11-11 22:46:06 -08:00
Jevin Jiang
38d062dbee [Mosaic TPU] Support dynamic DMA and ref slice on the 2nd minor when memref is untiled
* Generalize any untiled memref to have tiling (packing, 128)
* Support dynamic index on 2nd minor.
* Support dynamic shape on 2nd minor.

PiperOrigin-RevId: 695516124
2024-11-11 16:14:27 -08:00
Jake VanderPlas
3f98c57f7b jax.scipy.linalg.toeplitz: support implicit batching 2024-11-11 15:32:43 -08:00
Dan Foreman-Mackey
478ea0dcd6 Allow 64-bit output types from ffi_call regardless of enable_x64 flag. 2024-11-11 15:01:53 -08:00
jax authors
6892e628fb Update XLA dependency to use revision
e93a258e44.

PiperOrigin-RevId: 695470898
2024-11-11 13:52:20 -08:00
jax authors
56150286d5 Merge pull request #24841 from jax-ml:dependabot/github_actions/actions/cache-4.1.2
PiperOrigin-RevId: 695460084
2024-11-11 13:18:20 -08:00