24281 Commits

Author SHA1 Message Date
Frederic Bastien
a13b618c98 Document cudaMallocAsync as an experimental feature. 2024-12-06 13:41:35 -05:00
jax authors
e707edeafa Merge pull request #25034 from gnecula:poly_state
PiperOrigin-RevId: 698820458
2024-11-21 09:57:55 -08:00
Peter Buchlovsky
2178ed2fa4 [pallas] Add more test cases for Triton bitcast_convert_type lowering rule.
PiperOrigin-RevId: 698818103
2024-11-21 09:52:04 -08:00
Christos Perivolaropoulos
1d2dc17e5f [mgpu] Pointwise op can handle LHS splats.
PiperOrigin-RevId: 698818035
2024-11-21 09:50:21 -08:00
jax authors
b1b1ad622e Merge pull request #25018 from jakevdp:update-array-api
PiperOrigin-RevId: 698811575
2024-11-21 09:32:33 -08:00
Nitin Srinivasan
1e6654a031 Fix cron schedule to run past minute 0 every 2nd hour
In the previous schedule, we were running at every minute of every 2nd hour.

PiperOrigin-RevId: 698804124
2024-11-21 09:09:14 -08:00
jax authors
73352677f3 Merge pull request #25015 from barnesjoseph:add-google-sans
PiperOrigin-RevId: 698798602
2024-11-21 08:52:56 -08:00
jax authors
bf0150bb22 [JAX] Ignore xla_gpu_experimental_autotune_cache_mode when calculating module hash.
PiperOrigin-RevId: 698789020
2024-11-21 08:21:32 -08:00
Nitin Srinivasan
7d7a0fa249 Run the TPU workflow on new self-hosted runners
We are unable to run the TPU workflows because there are no active runners (https://github.com/jax-ml/jax/actions/runs/11879479226/job/33101456081). This adds the new self-hosted runners to the TPU workflow to fix the issue. The v3 type is disabled as it is not available yet.

PiperOrigin-RevId: 698772505
2024-11-21 07:26:05 -08:00
George Necula
0831e2e340 [shape_poly] Adding shape polymorphism support for the state primitives. 2024-11-21 06:17:01 -08:00
Mikhail Goncharov
1bc9df429d Integrate LLVM at llvm/llvm-project@33fcd6acc7
Updates LLVM usage to match
[33fcd6acc755](https://github.com/llvm/llvm-project/commit/33fcd6acc755)

PiperOrigin-RevId: 698742870
2024-11-21 05:25:33 -08:00
Sergei Lebedev
f18df8f39c [pallas:mosaic_gpu] Pulled delay_release into emit_pipeline
The implementation exactly matches the one we have in the lowering.

PiperOrigin-RevId: 698713343
2024-11-21 03:13:33 -08:00
Naums Mogers
e72b449089 Reverts c04aec9d525dd2e767495e41b98e82dd79315f37
PiperOrigin-RevId: 698654038
2024-11-20 22:45:46 -08:00
Yash Katariya
6568713a04 [sharding_in_types] Add concatenate_p support
PiperOrigin-RevId: 698621325
2024-11-20 20:12:44 -08:00
Jevin Jiang
869a53345d [Mosaic TPU] Add bound check for general vector store op.
PiperOrigin-RevId: 698577015
2024-11-20 17:28:04 -08:00
Yash Katariya
840cf3f7d2 [sharding_in_types] Add pad_p support to sharding_in_types to handle transpose to slice correctly.
PiperOrigin-RevId: 698573396
2024-11-20 17:14:22 -08:00
Justin Fu
1f6152d11e [Pallas] Use Pallas cost estimator for flash attention.
PiperOrigin-RevId: 698573265
2024-11-20 17:12:37 -08:00
jax authors
f39392eaf4 Merge pull request #25020 from jakevdp:lax-pad-validation
PiperOrigin-RevId: 698568589
2024-11-20 16:55:39 -08:00
barnesjoseph
bf7f9aa8f2 Adds Google Sans font 2024-11-20 16:44:41 -08:00
Jake VanderPlas
17825882d2 jax.lax.pad: improve input validation 2024-11-20 16:21:45 -08:00
jax authors
334bd4d0ba Merge pull request #25019 from jakevdp:lax-pad-doc
PiperOrigin-RevId: 698556681
2024-11-20 16:16:10 -08:00
Jake VanderPlas
2699e9507e DOC: add examples for jax.lax.pad 2024-11-20 15:13:14 -08:00
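For context on the API being documented: `jax.lax.pad` takes a per-dimension `(low, high, interior)` padding config. A minimal usage sketch:

```python
# jax.lax.pad pads with an explicit padding value; each dimension gets a
# (low, high, interior) triple, where `interior` inserts padding between
# adjacent elements.
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 3.0])
padded = lax.pad(x, 0.0, [(1, 2, 1)])  # 1 low, 2 high, 1 between elements
# padded == [0., 1., 0., 2., 0., 3., 0., 0.]
```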
Jake VanderPlas
f749fca760 [array api] use most recent version of array_api_tests 2024-11-20 14:50:06 -08:00
jax authors
6fe78042b5 Update XLA dependency to use revision
e763f8875b.

PiperOrigin-RevId: 698525361
2024-11-20 14:39:00 -08:00
Yash Katariya
9b94180846 [sharding_in_types] Add slice_p and squeeze_p sharding rule to make flash attention work in backward pass
For `slice_p`'s sharding rule, I error out if the operand dim is sharded and the output dim is not divisible by that axis size.

I am working on a design to make JAX support uneven sharding at the top level, after which `slice_p`'s sharding rule can just `return operand.sharding`. Another option is to add `out_sharding` to `slice`, but after uneven sharding support lands, it won't be necessary.

PiperOrigin-RevId: 698522980
2024-11-20 14:31:07 -08:00
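The divisibility rule described above can be sketched as a plain-Python check (all names here are invented for illustration; this is not JAX's actual implementation):

```python
# Hypothetical sketch of the slice_p sharding check described in the commit:
# if an operand dimension is sharded, the corresponding output dimension
# must be evenly divisible by the mesh axis size backing that sharding.
def check_slice_sharding(out_dim_size: int, axis_size: int,
                         dim_is_sharded: bool) -> None:
    if dim_is_sharded and out_dim_size % axis_size != 0:
        raise ValueError(
            f"sliced output dim of size {out_dim_size} is not divisible "
            f"by sharded mesh axis of size {axis_size}")

# A sharded dim sliced to an even multiple passes; an uneven slice errors.
check_slice_sharding(8, 4, dim_is_sharded=True)   # ok: 8 % 4 == 0
```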
jax authors
d219439d5b Merge pull request #25011 from jakevdp:jnp-linalg-module
PiperOrigin-RevId: 698517512
2024-11-20 14:13:40 -08:00
Jevin Jiang
9d2f62f811 [Pallas TPU] Support masked store
PiperOrigin-RevId: 698514079
2024-11-20 14:03:56 -08:00
Yash Katariya
40fc6598f9 [sharding_in_types] Make flash_attention forward pass in TPU pallas work nicely with sharding in types. The backward pass is still broken; I will fix it in follow-up CLs.
Set the abstract mesh context manager at the jit tracing boundary by looking at the mesh on the avals. In the future, this context manager will be user settable too.

Abstract mesh context manager is a new context manager with a new context variable and new trace_context entry which governs the cache behavior. If the abstract mesh context manager is not set, the default is `None`.

PiperOrigin-RevId: 698493184
2024-11-20 13:07:30 -08:00
jax authors
19b4996e5e Merge pull request #25013 from hawkinsp:relnotes
PiperOrigin-RevId: 698481034
2024-11-20 12:29:24 -08:00
Peter Hawkins
dfe27a1682 Mention stackless in the release notes. 2024-11-20 14:53:52 -05:00
jax authors
1a3e693ad5 Merge pull request #25008 from skye:barrier
PiperOrigin-RevId: 698461687
2024-11-20 11:34:35 -08:00
Jake VanderPlas
621e39de27 Set __module__ attribute of jax.numpy.linalg APIs 2024-11-20 10:47:23 -08:00
Sergei Lebedev
9584ee3bb9 [pallas:mosaic_gpu] Avoid using multiple indexers in the parallel grid test
It turns out we can mix a parallel grid with `plgpu.emit_pipeline` without doing
any indexing at all!

PiperOrigin-RevId: 698442820
2024-11-20 10:42:02 -08:00
Parker Schuh
2c9b917b9d Don't psum over auto mesh dims in _unmentioned2.
PiperOrigin-RevId: 698440525
2024-11-20 10:36:03 -08:00
jax authors
eab9026c14 Merge pull request #25004 from jax-ml:linearize-trace
PiperOrigin-RevId: 698438212
2024-11-20 10:29:29 -08:00
Dougal
d0f17c0c04 Make a direct linearize trace.
This is an alternative to doing JVP followed by partial eval. The linearize
trace has two parent traces, one for the primal computation and one for the
tangent computation. If we make the tangent trace a DynamicJaxprTrace then we
get staged linearization. If we make it the same as the primal trace then we get
primal and tangent computations occurring in lockstep (i.e. JVP). This is a neat trick
enabled by stackless, which now lives up to its name: with two parent traces we
have a tree of traces, not a linked-list stack.

Primitive ops can have their own linearization rules but as a fallback we can
derive a linearization rule for a single op using jvp/partial-eval.

For now this is all under a flag, `use_direct_linearize`, but I'm hoping we can
make this the default for linearize/grad. It should help with remat and AD
through state, which are awkward to express via partial eval.
2024-11-20 10:03:00 -08:00
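The user-facing entry point being reworked here is `jax.linearize`; a minimal illustration of the primal/tangent split it computes (the direct-linearize trace itself is internal):

```python
# jax.linearize returns the primal output plus a linear tangent map,
# i.e. the JVP of f at the given point as a reusable function.
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

y, f_lin = jax.linearize(f, 2.0)  # primal value and linearized function
t = f_lin(1.0)                    # directional derivative at x=2.0, tangent 1.0
```

Whether the tangent computation is staged out or run in lockstep with the primal is exactly the choice of parent trace described in the commit message.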
Christos Perivolaropoulos
8d84f28373 [pallas mgpu] Lowering for while loops as long as they are secretly for loops.
PiperOrigin-RevId: 698427307
2024-11-20 10:00:14 -08:00
Skye Wanderman-Milne
6222592625 Fix KeyError recently introduced in cloud_tpu_init.py
This fixes a bug introduced in https://github.com/jax-ml/jax/pull/24889
2024-11-20 17:46:06 +00:00
jax authors
439d34da15 Merge pull request #25005 from jakevdp:py313
PiperOrigin-RevId: 698413430
2024-11-20 09:15:03 -08:00
jax authors
800add2a03 Merge pull request #25007 from jakevdp:deps
PiperOrigin-RevId: 698413340
2024-11-20 09:13:05 -08:00
Jake VanderPlas
85e2969aea Deprecate several private APIs in jax.lib 2024-11-20 08:48:26 -08:00
Chris Jones
1e9e85a39e Simplify handling of DotAlgorithmPreset output types.
Create a clear distinction between the type used for accumulation and possible output types.

PiperOrigin-RevId: 698399447
2024-11-20 08:26:44 -08:00
Jake VanderPlas
a4266b5e31 Mention python 3.13 in docs & package metadata 2024-11-20 08:23:19 -08:00
jax authors
a582df0297 Update XLA dependency to use revision
fcee07f619.

PiperOrigin-RevId: 698371906
2024-11-20 06:39:06 -08:00
Sergei Lebedev
1df4b5f798 [pallas] Do not skip vmap tests on GPU when x64 is enabled
PiperOrigin-RevId: 698351984
2024-11-20 05:08:23 -08:00
Sergei Lebedev
04e4c69f7f [mosaic_gpu] Handle older jaxlibs in the profiler module
`measure` now raises a `RuntimeError` if the available `jaxlib` does not have
the required custom calls.

PiperOrigin-RevId: 698351662
2024-11-20 05:06:24 -08:00
Sergei Lebedev
f442d40f92 [mosaic_gpu] Fixed FragmentedArray comparisons with literals
PiperOrigin-RevId: 698343858
2024-11-20 04:31:28 -08:00
Sergei Lebedev
c76e5fe9a0 [pallas:mosaic_gpu] copy_smem_to_gmem now supports wait_read_only
PiperOrigin-RevId: 698343812
2024-11-20 04:29:33 -08:00
Peter Buchlovsky
14da7ebb76 [pallas:mosaic_gpu] Add Pallas Mosaic GPU lowering for jax.lax.bitcast_convert_type.
Only handles the case where the operand type and target type have the same bitwidth.

PiperOrigin-RevId: 698332564
2024-11-20 03:41:19 -08:00
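A usage sketch of `jax.lax.bitcast_convert_type` in the same-bitwidth case the new lowering handles (float32 to int32, both 32 bits wide):

```python
# bitcast_convert_type reinterprets the raw bits of the operand as the
# target dtype; with matching bitwidths the shape is unchanged.
import jax.numpy as jnp
from jax import lax

x = jnp.float32(1.0)
bits = lax.bitcast_convert_type(x, jnp.int32)
# IEEE-754 float32 1.0 has bit pattern 0x3F800000
```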
Peter Buchlovsky
1afb05e2e2 [mosaic_gpu] Fix signedness handling in FragmentedArray._pointwise.
Only propagate signedness from operands when the output type of `op` is an `ir.IntegerType`.

PiperOrigin-RevId: 698324596
2024-11-20 03:01:48 -08:00