23928 Commits

Author SHA1 Message Date
Peter Hawkins
a8f44c4700 Fix a CI failure under NumPy 2.1.
PiperOrigin-RevId: 691428702
2024-10-30 08:30:25 -07:00
Thomas Köppe
2bed1e88e4 Reverts 6dd1417d4a0a9ee31d8a014352b3a0fb2bcfcbaf
PiperOrigin-RevId: 691417832
2024-10-30 07:54:00 -07:00
Sergei Lebedev
2652ab5608 [mosaic_gpu] Added support for bitwise and, or and xor to FragmentedArray
PiperOrigin-RevId: 691411447
2024-10-30 07:30:48 -07:00
Sergei Lebedev
2b70ad30fb Removed unused _upcast_fp16_for_computation
PiperOrigin-RevId: 691409888
2024-10-30 07:24:13 -07:00
Dougal Maclaurin
a45b0856c5 Relax leak checks under the jax_data_dependent_tracing_fallback flag.
PiperOrigin-RevId: 691409392
2024-10-30 07:22:29 -07:00
Ayaka
8f96e9082a [Pallas TPU] Add lowerings for scalar absi
This PR is a follow-up of https://github.com/jax-ml/jax/pull/24504, which adds lowerings for scalar `absf` and `rsqrt`.

PiperOrigin-RevId: 691402430
2024-10-30 06:55:34 -07:00
jax authors
dfea163526 Merge pull request #24606 from jakevdp:dep-export
PiperOrigin-RevId: 691398663
2024-10-30 06:41:38 -07:00
Benjamin Chetioui
15a11365e4 Change the lowering rule for jax.lax.scan to avoid emitting a while loop
when the intent is to fully unroll the loop.

PiperOrigin-RevId: 691393597
2024-10-30 06:20:39 -07:00
Jake VanderPlas
e61a20b45a Remove deprecated jax.experimental.export module.
These tools are now available at jax.export.
2024-10-30 05:27:29 -07:00
Sergei Lebedev
f1c3109bf5 Removed mesh_utils._bounds_from_last_device which was only used in tests
PiperOrigin-RevId: 691342846
2024-10-30 02:43:56 -07:00
Sergei Lebedev
bdf2ca10fc Removed more dead code from various submodules
PiperOrigin-RevId: 691342832
2024-10-30 02:41:53 -07:00
Sergei Lebedev
908c8a8280 Removed unused _get_memory_space_from_ref
PiperOrigin-RevId: 691342830
2024-10-30 02:39:41 -07:00
Yash Katariya
e35e7f8e20 Allow sparsecore compute with T(8) layout via the layout API and compute_on API. To annotate compute on sparsecore, use @compute_on('tpu_sparsecore').
PiperOrigin-RevId: 691225280
2024-10-29 17:58:53 -07:00
Peter Hawkins
72f9a49358 Reverts 6d8950c04f23ad15a0443006f1e5bd21bfa84156
PiperOrigin-RevId: 691222756
2024-10-29 17:46:55 -07:00
jax authors
249f0101b3 Use approximate cost estimates for flash attention instead of reference XLA estimates.
PiperOrigin-RevId: 691209201
2024-10-29 16:53:03 -07:00
Vadym Matsishevskyi
6d8950c04f Cleanup requirements.in and test-requirements.txt
PiperOrigin-RevId: 691208596
2024-10-29 16:50:54 -07:00
Jake VanderPlas
b65fdcc612 pallas: remove build dependency on jax.experimental.export
jax.experimental.export is deprecated, and it looks like the build rule is unused.

PiperOrigin-RevId: 691205626
2024-10-29 16:41:50 -07:00
Sergei Lebedev
539c940946 Removed unused _tan_impl
Also removed the legacy lowering for `tan_p`.

PiperOrigin-RevId: 691195720
2024-10-29 16:09:05 -07:00
jax authors
5ad066eeaa [TPU][Mosaic] Replace tpu lowering (at canonicalization) for repeat with concat (which handles far more cases)
PiperOrigin-RevId: 691192121
2024-10-29 15:57:44 -07:00
jax authors
7c4cc9552c Merge pull request #24600 from jax-ml:fix-ref-cycle-bug
PiperOrigin-RevId: 691158252
2024-10-29 14:14:06 -07:00
jax authors
6dd1417d4a Merge pull request #24589 from jakevdp:device-get-key
PiperOrigin-RevId: 691154098
2024-10-29 14:03:18 -07:00
Dougal
80fde785f5 Fix a reference cycle bug.
When we use a context manager within a linear_util.transformation we should
leave the scope of the context manager before the final yield. Otherwise we
create spurious reference cycles. This was causing
CoreTest.test_reference_cycles to fail on Python 3.10 (but not 3.13 for some
reason).
2024-10-29 20:46:07 +00:00
Jake VanderPlas
b9ad519a29 Implement device_get for typed PRNG keys 2024-10-29 12:34:46 -07:00
jax authors
ecff5af095 Merge pull request #24581 from johmedr:patch-1
PiperOrigin-RevId: 691113648
2024-10-29 12:13:23 -07:00
jax authors
63e8aff268 Update XLA dependency to use revision
b5690e93ea.

PiperOrigin-RevId: 691102818
2024-10-29 11:45:45 -07:00
jax authors
f5656bcb11 Merge pull request #24510 from dfm:dot-algorithm-config
PiperOrigin-RevId: 691096482
2024-10-29 11:30:38 -07:00
Dougal Maclaurin
c36e1f7c1a Make trace dispatch purely a function of context rather than a function of both context and data. This lets us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind and so on.
PiperOrigin-RevId: 691086496
2024-10-29 11:04:31 -07:00
jax authors
c67cf51f15 Merge pull request #24580 from dfm:fix-ffi-test-segfault
PiperOrigin-RevId: 691062859
2024-10-29 10:05:47 -07:00
Johan Medrano
1667a7e6fb
Fix missing f-string format in slogdet error message 2024-10-29 15:23:53 +00:00
Dan Foreman-Mackey
03854cfce4 Allow dot algorithms in default_matmul_precision config. 2024-10-29 10:48:21 -04:00
Dan Foreman-Mackey
1785479cbd Fix segfault caused by uninitialized LAPACK in FFI test. 2024-10-29 10:41:59 -04:00
George Necula
eff6cb445b [export] Enable more cross-platform lowering tests for GPU.
Thanks to a lot of work by Dan Foreman-Mackey and others,
there has been much progress in how we lower linalg
primitives for GPU and we can now enable cross-platform
lowering tests for these primitives.

PiperOrigin-RevId: 691013252
2024-10-29 07:36:44 -07:00
jax authors
90b8ed2c2d Merge pull request #24579 from hawkinsp:prng
PiperOrigin-RevId: 691006062
2024-10-29 07:08:16 -07:00
Peter Hawkins
bee2bc443a Remove some dead code from gpu_prng.py 2024-10-29 09:29:56 -04:00
George Necula
5ccfc8d716 Reverts c3b4b76080dbedfebfed978c812338e2f680ee23
PiperOrigin-RevId: 690990311
2024-10-29 06:07:15 -07:00
Adam Paszke
8b21614973 [Pallas:MGPU] Add FlashAttention3 as an example
PiperOrigin-RevId: 690977852
2024-10-29 05:21:43 -07:00
jax authors
de68018473 [NFC][Mosaic TPU] Clarify layout comment block
PiperOrigin-RevId: 690977672
2024-10-29 05:20:08 -07:00
jax authors
eb9e362aac Merge pull request #24572 from jakevdp:fix-copyright
PiperOrigin-RevId: 690953008
2024-10-29 03:45:54 -07:00
Adam Paszke
3a87348bfc [Pallas:MGPU] Use shfl.sync after computing the warpgroup index
The shuffle is completely unnecessary, but there's some mysterious black magic pattern
patcher in ptxas that really wants us to do it. This tiny difference is what makes or
breaks a kernel: if we shuffle the warpgroup index in attention kernels, we see ~70%
utilization; if we don't we get at most ~50%...

PiperOrigin-RevId: 690928489
2024-10-29 02:04:44 -07:00
jax authors
c3b4b76080 Merge pull request #24545 from mattjj:improved-custom-gradient
PiperOrigin-RevId: 690847441
2024-10-28 20:24:28 -07:00
Jake VanderPlas
abf14323dc Adjust copyright notice.
Previously we had been pulling-in NumPy and SciPy docs at runtime, but
after the work in #21461 this is no longer the case.
2024-10-28 18:53:38 -07:00
Ayaka
a8d1048cb6 [Pallas] Add tests for jnp.logical_not
PiperOrigin-RevId: 690825419
2024-10-28 18:53:24 -07:00
Matthew Johnson
86a47a7d4e fix jax.custom_gradient to allow closing over non-autodiff tracers 2024-10-29 00:32:01 +00:00
jax authors
12d26053e3 [TPU][Mosaic] Add support for a no-op reshape where sublane_tiling = 1 and the res_tiled and src_tiled shapes both fill a full vreg (1024)
PiperOrigin-RevId: 690796348
2024-10-28 16:57:51 -07:00
jax authors
b90f5c1c1c Merge pull request #24492 from jakevdp:finalize-implements
PiperOrigin-RevId: 690776029
2024-10-28 15:49:26 -07:00
Jake VanderPlas
14030801a5 Remove obsolete implements() decorator & fix tests 2024-10-28 15:22:09 -07:00
jax authors
e82d5a973b Merge pull request #24488 from jakevdp:atan2-doc
PiperOrigin-RevId: 690757076
2024-10-28 14:46:06 -07:00
jax authors
e3bb123652 Merge pull request #24554 from jakevdp:int4-test
PiperOrigin-RevId: 690748022
2024-10-28 14:21:31 -07:00
Jake VanderPlas
20ed2f3317 Improve docs for jnp.arctan2 2024-10-28 14:17:41 -07:00
Hyeontaek Lim
77797f434d [JAX] Add the function API of jax.experimental.colocated_python
This change adds an experimental API `jax.experimental.colocated_python`. The
ultimate goal of this API is to provide a runtime-agnostic way to wrap a Python
code that runs close to (or on) accelerator hosts. Multi-controller JAX can
trivially achieve this colocated Python code execution today, while
single-controller JAX needed its own solution for distributed Python code
execution, which creates fragmentation of the user code for these two runtime
architectures. `colocated_python` is an attempt to define a single device model
and portable API to allow the user to write a single code once that can run on
both runtime architectures.

This change includes an implementation of the function API portion of
`jax.experimental.colocated_python`. A (stateful) object API will be added
separately. Also there will be a separate change that expresses serialized
functions as an IFRT `CustomCallProgram`.

It is currently in an early development stage. Please proceed with a caution
when using the API.

PiperOrigin-RevId: 690705899
2024-10-28 12:18:48 -07:00