23905 Commits

Author SHA1 Message Date
Dougal Maclaurin
48f24b6acb Remove ConcreteArray from JAX. It's easy to do trace-time concretization without it.
PiperOrigin-RevId: 691929385
2024-10-31 14:06:54 -07:00
jax authors
8536eca46e Update XLA dependency to use revision
edf18ce242.

PiperOrigin-RevId: 691908973
2024-10-31 13:06:07 -07:00
jax authors
c758373b9c Remove implicit sharding annotation for tpu custom call.
PiperOrigin-RevId: 691876343
2024-10-31 11:30:13 -07:00
Praveen Batra
8296f6e0ba [Mosaic] Add extension files for infer/apply vector layout.
PiperOrigin-RevId: 691868278
2024-10-31 11:08:37 -07:00
jax authors
7ff5a4eac2 Merge pull request #24190 from dfm:ffi-examples-gpu
PiperOrigin-RevId: 691862136
2024-10-31 10:53:27 -07:00
Vadym Matsishevskyi
a75d94622c Reverts 72f9a493589a1046e6927a5f16d7dc71df530743
PiperOrigin-RevId: 691843537
2024-10-31 10:05:22 -07:00
Praveen Batra
7d9f565647 [Mosaic] Fix some imports.
PiperOrigin-RevId: 691830491
2024-10-31 09:25:34 -07:00
Dan Foreman-Mackey
ce8dba98fb Move the CUDA end-to-end example to FFI examples workflow + hosted
runner.
2024-10-31 12:21:51 -04:00
jax authors
8abedda8a6 Merge pull request #24480 from dfm:dot-algorithm-plugin-enable
PiperOrigin-RevId: 691734684
2024-10-31 03:10:23 -07:00
Dan Foreman-Mackey
52ad60521c Run dot algorithm tests with PJRT plugin. 2024-10-31 06:01:11 -04:00
Benjamin Chetioui
c708a04c6e [Mosaic GPU] Add Python bindings for the Mosaic GPU MLIR dialect.
Also start moving the existing C++ tests to Python.

PiperOrigin-RevId: 691729887
2024-10-31 02:47:30 -07:00
Sergei Lebedev
85662f6dd8 [pallas:mosaic_gpu] plgpu.copy_smem_to_gmem no longer transparently commits SMEM
Users are expected to call `pltpu.commit_smem` manually instead.

PiperOrigin-RevId: 691724662
2024-10-31 02:21:10 -07:00
Dimitar (Mitko) Asenov
7d504cd95a [MOSAIC:GPU] Extend the mosaic mlir dialect with fragmented layouts.
PiperOrigin-RevId: 691712579
2024-10-31 01:29:22 -07:00
jax authors
5aeffde707 [Mosaic] Extend tpu matmulop to have dimension dims. Add support for batching and simple transposition.
PiperOrigin-RevId: 691706218
2024-10-31 00:59:13 -07:00
Dougal Maclaurin
f355dcf34b Remove UnshapedArray values from JAX (it remains as an abstract class).
Part of a plan to move away from our "abstract value" lattice to more traditional types.

PiperOrigin-RevId: 691626481
2024-10-30 18:53:51 -07:00
Yash Katariya
7f4a34e12b Remove the variant since sparsecore is only on v5p and it's device kind is TPU v5.
PiperOrigin-RevId: 691586791
2024-10-30 16:18:54 -07:00
Jake VanderPlas
0181cb396d Re-land #24589 with fixes to handle dtype that is not compatible with NumPy.
Previously, this change did not account for that fact that `device_get` may be called on objects that have a non-NumPy-compatible `dtype` attribute, such as tensorflow tensors. This change adds new dtype handling aimed at being robust to this case.

Reverts 2bed1e88e4276558e4dd5e6a6d5afe6f2396a25d

PiperOrigin-RevId: 691568933
2024-10-30 15:13:00 -07:00
Naums Mogers
242e6634ff [Mosaic] Add the core type enum
The new attribute allows differentiating compilation by target core.

PiperOrigin-RevId: 691531726
2024-10-30 13:23:34 -07:00
jax authors
af14c43893 Update XLA dependency to use revision
2d9d84487e.

PiperOrigin-RevId: 691516089
2024-10-30 12:36:35 -07:00
Bart Chrzaszcz
44158ab0e4 #sdy add shardy CPU config for all JAX tests, disabling any known failing test cases.
Only test cases breaking on CPU are related to:
- pure callbacks
- export
- shard alike

Note that `layout_test` is broken on TPU, leaving a comment saying to enable it.

Also fixed `shard_map_test` test that was broken when running Shardy on one TPU, and `aot_test` which was breaking due to calling a different C++ StableHLO compilation function.

PiperOrigin-RevId: 691496997
2024-10-30 11:40:20 -07:00
Dougal Maclaurin
32bf19ac6f Add a temporary fix for spurious debug_nans errors when round-tripping jaxprs.
debug_nans is sometimes disabled locally at the traceable level by ops that work with nans internally, like jnp.var. But we don't capture this local change-of-context in the jaxpr. The right thing to do is to add contexts to our jaxpr representation so that we can capture these local context modifications. In the meantime, disabling the checks when we round-trip prevents those ops producing spurious errors.

PiperOrigin-RevId: 691494516
2024-10-30 11:34:08 -07:00
jax authors
3904ced255 [Mosaic] Test only cl - add triu test, skip bf16 due to select being native bitwidth only
PiperOrigin-RevId: 691477248
2024-10-30 10:48:44 -07:00
jax authors
99ea4c1a4a [Fix] Put * packing into reshape no-op condition (Bug in my original CL)
PiperOrigin-RevId: 691476663
2024-10-30 10:47:23 -07:00
Sergei Lebedev
409517fcbc [pallas:mosaic_gpu] Disabled verbose lowering errors in Mosaic GPU tests
PiperOrigin-RevId: 691472782
2024-10-30 10:37:32 -07:00
Sergei Lebedev
6283eab2ff [pallas] Added a flag disabling verbose error reporting
PiperOrigin-RevId: 691463398
2024-10-30 10:13:22 -07:00
Nitin Srinivasan
da994d3552 Move utility functions in build.py to utils.py
This commit is the first step towards re-working the build CLI. It moves all the auxiliary functions used by the CLI into a separate script for easier maintenance and readability.

PiperOrigin-RevId: 691458051
2024-10-30 10:00:32 -07:00
Tzu-Wei Sung
d2f5804449 [Pallas] Add test cases for var + constant.
PiperOrigin-RevId: 691450143
2024-10-30 09:37:50 -07:00
Peter Hawkins
a8f44c4700 Fix a CI failure under NumPy 2.1.
PiperOrigin-RevId: 691428702
2024-10-30 08:30:25 -07:00
Thomas Köppe
2bed1e88e4 Reverts 6dd1417d4a0a9ee31d8a014352b3a0fb2bcfcbaf
PiperOrigin-RevId: 691417832
2024-10-30 07:54:00 -07:00
Sergei Lebedev
2652ab5608 [mosaic_gpu] Added support for bitwise and, or and xor to FragmentedArray
PiperOrigin-RevId: 691411447
2024-10-30 07:30:48 -07:00
Sergei Lebedev
2b70ad30fb Removed unused _upcast_fp16_for_computation
PiperOrigin-RevId: 691409888
2024-10-30 07:24:13 -07:00
Dougal Maclaurin
a45b0856c5 Relax leak checks under the jax_data_dependent_tracing_fallback flag.
PiperOrigin-RevId: 691409392
2024-10-30 07:22:29 -07:00
Ayaka
8f96e9082a [Pallas TPU] Add lowerings for scalar absi
This PR is a follow-up of https://github.com/jax-ml/jax/pull/24504, which adds lowerings for scalar `absf` and `rsqrt`.

PiperOrigin-RevId: 691402430
2024-10-30 06:55:34 -07:00
jax authors
dfea163526 Merge pull request #24606 from jakevdp:dep-export
PiperOrigin-RevId: 691398663
2024-10-30 06:41:38 -07:00
Benjamin Chetioui
15a11365e4 Change the lowering rule for jax.lax.scan to avoid emitting a while loop
when the intent is to fully unroll the loop.

PiperOrigin-RevId: 691393597
2024-10-30 06:20:39 -07:00
Jake VanderPlas
e61a20b45a Remove deprecated jax.experimental.export module.
These tools are now available at jax.export.
2024-10-30 05:27:29 -07:00
Sergei Lebedev
f1c3109bf5 Removed mesh_utils._bounds_from_last_device which was only used in tests
PiperOrigin-RevId: 691342846
2024-10-30 02:43:56 -07:00
Sergei Lebedev
bdf2ca10fc Removed more dead code from various submodules
PiperOrigin-RevId: 691342832
2024-10-30 02:41:53 -07:00
Sergei Lebedev
908c8a8280 Removed unused _get_memory_space_from_ref
PiperOrigin-RevId: 691342830
2024-10-30 02:39:41 -07:00
Yash Katariya
e35e7f8e20 Allow sparsecore compute with T(8) layout via the layout API and compute_on API. To annotate compute on sparsecore, use @compute_on('tpu_sparsecore').
PiperOrigin-RevId: 691225280
2024-10-29 17:58:53 -07:00
Peter Hawkins
72f9a49358 Reverts 6d8950c04f23ad15a0443006f1e5bd21bfa84156
PiperOrigin-RevId: 691222756
2024-10-29 17:46:55 -07:00
jax authors
249f0101b3 Use approximate cost estimates for flash attention instead of reference XLA estimates.
PiperOrigin-RevId: 691209201
2024-10-29 16:53:03 -07:00
Vadym Matsishevskyi
6d8950c04f Cleanup requirements.in and test-requirements.txt
PiperOrigin-RevId: 691208596
2024-10-29 16:50:54 -07:00
Jake VanderPlas
b65fdcc612 pallas: remove build dependency on jax.experimental.export
jax.experimental.export is deprecated, and it looks like the build rule is unused.

PiperOrigin-RevId: 691205626
2024-10-29 16:41:50 -07:00
Sergei Lebedev
539c940946 Removed unused _tan_impl
Also removed the legacy lowering for `tan_p`.

PiperOrigin-RevId: 691195720
2024-10-29 16:09:05 -07:00
jax authors
5ad066eeaa [TPU][Mosaic] Replace tpu lowering (at canonicalization) for repeat with concat (which handles far more cases)
PiperOrigin-RevId: 691192121
2024-10-29 15:57:44 -07:00
jax authors
7c4cc9552c Merge pull request #24600 from jax-ml:fix-ref-cycle-bug
PiperOrigin-RevId: 691158252
2024-10-29 14:14:06 -07:00
jax authors
6dd1417d4a Merge pull request #24589 from jakevdp:device-get-key
PiperOrigin-RevId: 691154098
2024-10-29 14:03:18 -07:00
Dougal
80fde785f5 Fix a reference cycle bug.
When we use a context manager within a linear_util.transformation we should
leave the scope of the context manager before the final yield. Otherwise we
create spurious reference cycles. This was causing
CoreTest.test_reference_cycles to fail on Python 3.10 (but not 3.13 for some
reason).
2024-10-29 20:46:07 +00:00
Jake VanderPlas
b9ad519a29 Implement device_get for typed PRNG keys 2024-10-29 12:34:46 -07:00