Peter Hawkins
a8f44c4700
Fix a CI failure under NumPy 2.1.
...
PiperOrigin-RevId: 691428702
2024-10-30 08:30:25 -07:00
Thomas Köppe
2bed1e88e4
Reverts 6dd1417d4a0a9ee31d8a014352b3a0fb2bcfcbaf
...
PiperOrigin-RevId: 691417832
2024-10-30 07:54:00 -07:00
Sergei Lebedev
2652ab5608
[mosaic_gpu] Added support for bitwise and, or and xor to FragmentedArray
...
PiperOrigin-RevId: 691411447
2024-10-30 07:30:48 -07:00
Sergei Lebedev
2b70ad30fb
Removed unused _upcast_fp16_for_computation
...
PiperOrigin-RevId: 691409888
2024-10-30 07:24:13 -07:00
Dougal Maclaurin
a45b0856c5
Relax leak checks under the jax_data_dependent_tracing_fallback flag.
...
PiperOrigin-RevId: 691409392
2024-10-30 07:22:29 -07:00
Ayaka
8f96e9082a
[Pallas TPU] Add lowerings for scalar absi
...
This PR is a follow-up of https://github.com/jax-ml/jax/pull/24504 , which adds lowerings for scalar `absf` and `rsqrt`.
PiperOrigin-RevId: 691402430
2024-10-30 06:55:34 -07:00
jax authors
dfea163526
Merge pull request #24606 from jakevdp:dep-export
...
PiperOrigin-RevId: 691398663
2024-10-30 06:41:38 -07:00
Benjamin Chetioui
15a11365e4
Change the lowering rule for jax.lax.scan
to avoid emitting a while
loop
...
when the intent is to fully unroll the loop.
PiperOrigin-RevId: 691393597
2024-10-30 06:20:39 -07:00
Jake VanderPlas
e61a20b45a
Remove deprecated jax.experimental.export module.
...
These tools are now available at jax.export.
2024-10-30 05:27:29 -07:00
Sergei Lebedev
f1c3109bf5
Removed mesh_utils._bounds_from_last_device
which was only used in tests
...
PiperOrigin-RevId: 691342846
2024-10-30 02:43:56 -07:00
Sergei Lebedev
bdf2ca10fc
Removed more dead code from various submodules
...
PiperOrigin-RevId: 691342832
2024-10-30 02:41:53 -07:00
Sergei Lebedev
908c8a8280
Removed unused _get_memory_space_from_ref
...
PiperOrigin-RevId: 691342830
2024-10-30 02:39:41 -07:00
Yash Katariya
e35e7f8e20
Allow sparsecore compute with T(8) layout via the layout API and compute_on
API. To annotate compute on sparsecore, use @compute_on('tpu_sparsecore')
.
...
PiperOrigin-RevId: 691225280
2024-10-29 17:58:53 -07:00
Peter Hawkins
72f9a49358
Reverts 6d8950c04f23ad15a0443006f1e5bd21bfa84156
...
PiperOrigin-RevId: 691222756
2024-10-29 17:46:55 -07:00
jax authors
249f0101b3
Use approximate cost estimates for flash attention instead of reference XLA estimates.
...
PiperOrigin-RevId: 691209201
2024-10-29 16:53:03 -07:00
Vadym Matsishevskyi
6d8950c04f
Cleanup requirements.in and test-requirements.txt
...
PiperOrigin-RevId: 691208596
2024-10-29 16:50:54 -07:00
Jake VanderPlas
b65fdcc612
pallas: remove build dependency on jax.experimental.export
...
jax.experimental.export is deprecated, and it looks like the build rule is unused.
PiperOrigin-RevId: 691205626
2024-10-29 16:41:50 -07:00
Sergei Lebedev
539c940946
Removed unused _tan_impl
...
Also removed the legacy lowering for `tan_p`.
PiperOrigin-RevId: 691195720
2024-10-29 16:09:05 -07:00
jax authors
5ad066eeaa
[TPU][Mosaic] Replace tpu lowering (at canonicalization) for repeat with concat (which handles far more cases)
...
PiperOrigin-RevId: 691192121
2024-10-29 15:57:44 -07:00
jax authors
7c4cc9552c
Merge pull request #24600 from jax-ml:fix-ref-cycle-bug
...
PiperOrigin-RevId: 691158252
2024-10-29 14:14:06 -07:00
jax authors
6dd1417d4a
Merge pull request #24589 from jakevdp:device-get-key
...
PiperOrigin-RevId: 691154098
2024-10-29 14:03:18 -07:00
Dougal
80fde785f5
Fix a reference cycle bug.
...
When we use a context manager within a linear_util.transformation we should
leave the scope of the context manager before the final yield. Otherwise we
create spurious reference cycles. This was causing
CoreTest.test_reference_cycles to fail on Python 3.10 (but not 3.13 for some
reason).
2024-10-29 20:46:07 +00:00
Jake VanderPlas
b9ad519a29
Implement device_get for typed PRNG keys
2024-10-29 12:34:46 -07:00
jax authors
ecff5af095
Merge pull request #24581 from johmedr:patch-1
...
PiperOrigin-RevId: 691113648
2024-10-29 12:13:23 -07:00
jax authors
63e8aff268
Update XLA dependency to use revision
...
b5690e93ea
.
PiperOrigin-RevId: 691102818
2024-10-29 11:45:45 -07:00
jax authors
f5656bcb11
Merge pull request #24510 from dfm:dot-algorithm-config
...
PiperOrigin-RevId: 691096482
2024-10-29 11:30:38 -07:00
Dougal Maclaurin
c36e1f7c1a
Make trace dispatch purely a function of context rather than a function of both context and data. This lets us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind and so on.
...
PiperOrigin-RevId: 691086496
2024-10-29 11:04:31 -07:00
jax authors
c67cf51f15
Merge pull request #24580 from dfm:fix-ffi-test-segfault
...
PiperOrigin-RevId: 691062859
2024-10-29 10:05:47 -07:00
Johan Medrano
1667a7e6fb
Fix missing f-string format in slogdet error message
2024-10-29 15:23:53 +00:00
Dan Foreman-Mackey
03854cfce4
Allow dot algorithms in default_matmul_precision config.
2024-10-29 10:48:21 -04:00
Dan Foreman-Mackey
1785479cbd
Fix segfault caused by uninitialized LAPACK in FFI test.
2024-10-29 10:41:59 -04:00
George Necula
eff6cb445b
[export] Enable more cross-platform lowering tests for GPU.
...
Thanks to a lot of work by Dan Foreman-Mackey and others,
there has been much progress in how we lower linalg
primitives for GPU and we can now enable cross-platform
lowering tests for these primitives.
PiperOrigin-RevId: 691013252
2024-10-29 07:36:44 -07:00
jax authors
90b8ed2c2d
Merge pull request #24579 from hawkinsp:prng
...
PiperOrigin-RevId: 691006062
2024-10-29 07:08:16 -07:00
Peter Hawkins
bee2bc443a
Remove some dead code from gpu_prng.py
2024-10-29 09:29:56 -04:00
George Necula
5ccfc8d716
Reverts c3b4b76080dbedfebfed978c812338e2f680ee23
...
PiperOrigin-RevId: 690990311
2024-10-29 06:07:15 -07:00
Adam Paszke
8b21614973
[Pallas:MGPU] Add FlashAttention3 as an example
...
PiperOrigin-RevId: 690977852
2024-10-29 05:21:43 -07:00
jax authors
de68018473
[NFC][Mosaic TPU] Clarify layout comment block
...
PiperOrigin-RevId: 690977672
2024-10-29 05:20:08 -07:00
jax authors
eb9e362aac
Merge pull request #24572 from jakevdp:fix-copyright
...
PiperOrigin-RevId: 690953008
2024-10-29 03:45:54 -07:00
Adam Paszke
3a87348bfc
[Pallas:MGPU] Use shfl.sync after computing the warpgroup index
...
The shuffle is completely unnecessary, but there's some mysterious black magic pattern
patcher in ptxas that really wants us to do it. This tiny difference is what makes or
breaks a kernel: if we shuffle the warpgroup index in attention kernels, we see ~70%
utilization; if we don't we get at most ~50%...
PiperOrigin-RevId: 690928489
2024-10-29 02:04:44 -07:00
jax authors
c3b4b76080
Merge pull request #24545 from mattjj:improved-custom-gradient
...
PiperOrigin-RevId: 690847441
2024-10-28 20:24:28 -07:00
Jake VanderPlas
abf14323dc
Adjust copyright notice.
...
Previously we had been pulling-in NumPy and SciPy docs at runtime, but
after the work in #21461 this is no longer the case.
2024-10-28 18:53:38 -07:00
Ayaka
a8d1048cb6
[Pallas] Add tests for jnp.logical_not
...
PiperOrigin-RevId: 690825419
2024-10-28 18:53:24 -07:00
Matthew Johnson
86a47a7d4e
fix jax.custom_gradient to allow closing over non-autodiff tracers
2024-10-29 00:32:01 +00:00
jax authors
12d26053e3
[TPU][Mosaic] Add support for a no-op reshape where sublane_tiling = 1 and the res_tiled and src_tiled shapes both fill a full vreg (1024)
...
PiperOrigin-RevId: 690796348
2024-10-28 16:57:51 -07:00
jax authors
b90f5c1c1c
Merge pull request #24492 from jakevdp:finalize-implements
...
PiperOrigin-RevId: 690776029
2024-10-28 15:49:26 -07:00
Jake VanderPlas
14030801a5
Remove obsolete implements() decorator & fix tests
2024-10-28 15:22:09 -07:00
jax authors
e82d5a973b
Merge pull request #24488 from jakevdp:atan2-doc
...
PiperOrigin-RevId: 690757076
2024-10-28 14:46:06 -07:00
jax authors
e3bb123652
Merge pull request #24554 from jakevdp:int4-test
...
PiperOrigin-RevId: 690748022
2024-10-28 14:21:31 -07:00
Jake VanderPlas
20ed2f3317
Improve docs for jnp.arctan2
2024-10-28 14:17:41 -07:00
Hyeontaek Lim
77797f434d
[JAX] Add the function API of jax.experimental.colocated_python
...
This change adds an experimental API `jax.experimental.colocated_python`. The
ultimate goal of this API is to provide a runtime-agnostic way to wrap a Python
code that runs close to (or on) accelerator hosts. Multi-controller JAX can
trivially achieve this colocated Python code execution today, while
single-controller JAX needed its own solution for distributed Python code
execution, which creates fragmentation of the user code for these two runtime
architectures. `colocated_python` is an attempt to define a single device model
and portable API to allow the user to write a single code once that can run on
both runtime architectures.
This change includes an implementation of the function API portion of
`jax.experimental.colocated_python`. A (stateful) object API will be added
separately. Also there will be a separate change that expresses serialized
functions as an IFRT `CustomCallProgram`.
It is currently in an early development stage. Please proceed with a caution
when using the API.
PiperOrigin-RevId: 690705899
2024-10-28 12:18:48 -07:00