Dougal Maclaurin
48f24b6acb
Remove ConcreteArray from JAX. It's easy to do trace-time concretization without it.
...
PiperOrigin-RevId: 691929385
2024-10-31 14:06:54 -07:00
jax authors
8536eca46e
Update XLA dependency to use revision
...
edf18ce242
.
PiperOrigin-RevId: 691908973
2024-10-31 13:06:07 -07:00
jax authors
c758373b9c
Remove implicit sharding annotation for tpu custom call.
...
PiperOrigin-RevId: 691876343
2024-10-31 11:30:13 -07:00
Praveen Batra
8296f6e0ba
[Mosaic] Add extension files for infer/apply vector layout.
...
PiperOrigin-RevId: 691868278
2024-10-31 11:08:37 -07:00
jax authors
7ff5a4eac2
Merge pull request #24190 from dfm:ffi-examples-gpu
...
PiperOrigin-RevId: 691862136
2024-10-31 10:53:27 -07:00
Vadym Matsishevskyi
a75d94622c
Reverts 72f9a493589a1046e6927a5f16d7dc71df530743
...
PiperOrigin-RevId: 691843537
2024-10-31 10:05:22 -07:00
Praveen Batra
7d9f565647
[Mosaic] Fix some imports.
...
PiperOrigin-RevId: 691830491
2024-10-31 09:25:34 -07:00
Dan Foreman-Mackey
ce8dba98fb
Move the CUDA end-to-end example to FFI examples workflow + hosted
...
runner.
2024-10-31 12:21:51 -04:00
jax authors
8abedda8a6
Merge pull request #24480 from dfm:dot-algorithm-plugin-enable
...
PiperOrigin-RevId: 691734684
2024-10-31 03:10:23 -07:00
Dan Foreman-Mackey
52ad60521c
Run dot algorithm tests with PJRT plugin.
2024-10-31 06:01:11 -04:00
Benjamin Chetioui
c708a04c6e
[Mosaic GPU] Add Python bindings for the Mosaic GPU MLIR dialect.
...
Also start moving the existing C++ tests to Python.
PiperOrigin-RevId: 691729887
2024-10-31 02:47:30 -07:00
Sergei Lebedev
85662f6dd8
[pallas:mosaic_gpu] plgpu.copy_smem_to_gmem
no longer transparently commits SMEM
...
Users are expected to call `pltpu.commit_smem` manually instead.
PiperOrigin-RevId: 691724662
2024-10-31 02:21:10 -07:00
Dimitar (Mitko) Asenov
7d504cd95a
[MOSAIC:GPU] Extend the mosaic mlir dialect with fragmented layouts.
...
PiperOrigin-RevId: 691712579
2024-10-31 01:29:22 -07:00
jax authors
5aeffde707
[Mosaic] Extend tpu matmulop to have dimension dims. Add support for batching and simple transposition.
...
PiperOrigin-RevId: 691706218
2024-10-31 00:59:13 -07:00
Dougal Maclaurin
f355dcf34b
Remove UnshapedArray values from JAX (it remains as an abstract class).
...
Part of a plan to move away from our "abstract value" lattice to more traditional types.
PiperOrigin-RevId: 691626481
2024-10-30 18:53:51 -07:00
Yash Katariya
7f4a34e12b
Remove the variant
since sparsecore is only on v5p and it's device kind is TPU v5
.
...
PiperOrigin-RevId: 691586791
2024-10-30 16:18:54 -07:00
Jake VanderPlas
0181cb396d
Re-land #24589 with fixes to handle dtype
that is not compatible with NumPy.
...
Previously, this change did not account for that fact that `device_get` may be called on objects that have a non-NumPy-compatible `dtype` attribute, such as tensorflow tensors. This change adds new dtype handling aimed at being robust to this case.
Reverts 2bed1e88e4276558e4dd5e6a6d5afe6f2396a25d
PiperOrigin-RevId: 691568933
2024-10-30 15:13:00 -07:00
Naums Mogers
242e6634ff
[Mosaic] Add the core type enum
...
The new attribute allows differentiating compilation by target core.
PiperOrigin-RevId: 691531726
2024-10-30 13:23:34 -07:00
jax authors
af14c43893
Update XLA dependency to use revision
...
2d9d84487e
.
PiperOrigin-RevId: 691516089
2024-10-30 12:36:35 -07:00
Bart Chrzaszcz
44158ab0e4
#sdy add shardy CPU config for all JAX tests, disabling any known failing test cases.
...
Only test cases breaking on CPU are related to:
- pure callbacks
- export
- shard alike
Note that `layout_test` is broken on TPU, leaving a comment saying to enable it.
Also fixed `shard_map_test` test that was broken when running Shardy on one TPU, and `aot_test` which was breaking due to calling a different C++ StableHLO compilation function.
PiperOrigin-RevId: 691496997
2024-10-30 11:40:20 -07:00
Dougal Maclaurin
32bf19ac6f
Add a temporary fix for spurious debug_nans errors when round-tripping jaxprs.
...
debug_nans is sometimes disabled locally at the traceable level by ops that work with nans internally, like jnp.var. But we don't capture this local change-of-context in the jaxpr. The right thing to do is to add contexts to our jaxpr representation so that we can capture these local context modifications. In the meantime, disabling the checks when we round-trip prevents those ops producing spurious errors.
PiperOrigin-RevId: 691494516
2024-10-30 11:34:08 -07:00
jax authors
3904ced255
[Mosaic] Test only cl - add triu test, skip bf16 due to select being native bitwidth only
...
PiperOrigin-RevId: 691477248
2024-10-30 10:48:44 -07:00
jax authors
99ea4c1a4a
[Fix] Put * packing into reshape no-op condition (Bug in my original CL)
...
PiperOrigin-RevId: 691476663
2024-10-30 10:47:23 -07:00
Sergei Lebedev
409517fcbc
[pallas:mosaic_gpu] Disabled verbose lowering errors in Mosaic GPU tests
...
PiperOrigin-RevId: 691472782
2024-10-30 10:37:32 -07:00
Sergei Lebedev
6283eab2ff
[pallas] Added a flag disabling verbose error reporting
...
PiperOrigin-RevId: 691463398
2024-10-30 10:13:22 -07:00
Nitin Srinivasan
da994d3552
Move utility functions in build.py to utils.py
...
This commit is the first step towards re-working the build CLI. It moves all the auxiliary functions used by the CLI into a separate script for easier maintenance and readability.
PiperOrigin-RevId: 691458051
2024-10-30 10:00:32 -07:00
Tzu-Wei Sung
d2f5804449
[Pallas] Add test cases for var + constant.
...
PiperOrigin-RevId: 691450143
2024-10-30 09:37:50 -07:00
Peter Hawkins
a8f44c4700
Fix a CI failure under NumPy 2.1.
...
PiperOrigin-RevId: 691428702
2024-10-30 08:30:25 -07:00
Thomas Köppe
2bed1e88e4
Reverts 6dd1417d4a0a9ee31d8a014352b3a0fb2bcfcbaf
...
PiperOrigin-RevId: 691417832
2024-10-30 07:54:00 -07:00
Sergei Lebedev
2652ab5608
[mosaic_gpu] Added support for bitwise and, or and xor to FragmentedArray
...
PiperOrigin-RevId: 691411447
2024-10-30 07:30:48 -07:00
Sergei Lebedev
2b70ad30fb
Removed unused _upcast_fp16_for_computation
...
PiperOrigin-RevId: 691409888
2024-10-30 07:24:13 -07:00
Dougal Maclaurin
a45b0856c5
Relax leak checks under the jax_data_dependent_tracing_fallback flag.
...
PiperOrigin-RevId: 691409392
2024-10-30 07:22:29 -07:00
Ayaka
8f96e9082a
[Pallas TPU] Add lowerings for scalar absi
...
This PR is a follow-up of https://github.com/jax-ml/jax/pull/24504 , which adds lowerings for scalar `absf` and `rsqrt`.
PiperOrigin-RevId: 691402430
2024-10-30 06:55:34 -07:00
jax authors
dfea163526
Merge pull request #24606 from jakevdp:dep-export
...
PiperOrigin-RevId: 691398663
2024-10-30 06:41:38 -07:00
Benjamin Chetioui
15a11365e4
Change the lowering rule for jax.lax.scan
to avoid emitting a while
loop
...
when the intent is to fully unroll the loop.
PiperOrigin-RevId: 691393597
2024-10-30 06:20:39 -07:00
Jake VanderPlas
e61a20b45a
Remove deprecated jax.experimental.export module.
...
These tools are now available at jax.export.
2024-10-30 05:27:29 -07:00
Sergei Lebedev
f1c3109bf5
Removed mesh_utils._bounds_from_last_device
which was only used in tests
...
PiperOrigin-RevId: 691342846
2024-10-30 02:43:56 -07:00
Sergei Lebedev
bdf2ca10fc
Removed more dead code from various submodules
...
PiperOrigin-RevId: 691342832
2024-10-30 02:41:53 -07:00
Sergei Lebedev
908c8a8280
Removed unused _get_memory_space_from_ref
...
PiperOrigin-RevId: 691342830
2024-10-30 02:39:41 -07:00
Yash Katariya
e35e7f8e20
Allow sparsecore compute with T(8) layout via the layout API and compute_on
API. To annotate compute on sparsecore, use @compute_on('tpu_sparsecore')
.
...
PiperOrigin-RevId: 691225280
2024-10-29 17:58:53 -07:00
Peter Hawkins
72f9a49358
Reverts 6d8950c04f23ad15a0443006f1e5bd21bfa84156
...
PiperOrigin-RevId: 691222756
2024-10-29 17:46:55 -07:00
jax authors
249f0101b3
Use approximate cost estimates for flash attention instead of reference XLA estimates.
...
PiperOrigin-RevId: 691209201
2024-10-29 16:53:03 -07:00
Vadym Matsishevskyi
6d8950c04f
Cleanup requirements.in and test-requirements.txt
...
PiperOrigin-RevId: 691208596
2024-10-29 16:50:54 -07:00
Jake VanderPlas
b65fdcc612
pallas: remove build dependency on jax.experimental.export
...
jax.experimental.export is deprecated, and it looks like the build rule is unused.
PiperOrigin-RevId: 691205626
2024-10-29 16:41:50 -07:00
Sergei Lebedev
539c940946
Removed unused _tan_impl
...
Also removed the legacy lowering for `tan_p`.
PiperOrigin-RevId: 691195720
2024-10-29 16:09:05 -07:00
jax authors
5ad066eeaa
[TPU][Mosaic] Replace tpu lowering (at canonicalization) for repeat with concat (which handles far more cases)
...
PiperOrigin-RevId: 691192121
2024-10-29 15:57:44 -07:00
jax authors
7c4cc9552c
Merge pull request #24600 from jax-ml:fix-ref-cycle-bug
...
PiperOrigin-RevId: 691158252
2024-10-29 14:14:06 -07:00
jax authors
6dd1417d4a
Merge pull request #24589 from jakevdp:device-get-key
...
PiperOrigin-RevId: 691154098
2024-10-29 14:03:18 -07:00
Dougal
80fde785f5
Fix a reference cycle bug.
...
When we use a context manager within a linear_util.transformation we should
leave the scope of the context manager before the final yield. Otherwise we
create spurious reference cycles. This was causing
CoreTest.test_reference_cycles to fail on Python 3.10 (but not 3.13 for some
reason).
2024-10-29 20:46:07 +00:00
Jake VanderPlas
b9ad519a29
Implement device_get for typed PRNG keys
2024-10-29 12:34:46 -07:00