24251 Commits

Author SHA1 Message Date
Peter Hawkins
dfe27a1682 Mention stackless in the release notes. 2024-11-20 14:53:52 -05:00
jax authors
1a3e693ad5 Merge pull request #25008 from skye:barrier
PiperOrigin-RevId: 698461687
2024-11-20 11:34:35 -08:00
Sergei Lebedev
9584ee3bb9 [pallas:mosaic_gpu] Avoid using multiple indexers in the parallel grid test
Turns out we can mix a parallel grid with `plgpu.emit_pipeline` without doing
any indexing at all!

PiperOrigin-RevId: 698442820
2024-11-20 10:42:02 -08:00
Parker Schuh
2c9b917b9d Don't psum over auto mesh dims in _unmentioned2.
PiperOrigin-RevId: 698440525
2024-11-20 10:36:03 -08:00
jax authors
eab9026c14 Merge pull request #25004 from jax-ml:linearize-trace
PiperOrigin-RevId: 698438212
2024-11-20 10:29:29 -08:00
Dougal
d0f17c0c04 Make a direct linearize trace.
This is an alternative to doing JVP followed by partial eval. The linearize
trace has two parent traces, one for the primal computation and one for the
tangent computation. If we make the tangent trace a DynamicJaxprTrace, we get
staged linearization. If we make it the same as the primal trace, we get primal
and tangent computations occurring in lockstep (i.e., ordinary JVP). This is a
neat trick enabled by stackless, which now lives up to its name: with two parent
traces we have a tree of traces, not a linked-list stack.

Primitive ops can have their own linearization rules, but as a fallback we can
derive a linearization rule for a single op using JVP plus partial eval.

For now this is all under a flag, `use_direct_linearize`, but I'm hoping we can
make this the default for linearize/grad. It should help with remat and AD
through state, which are awkward to express via partial eval.
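
As a rough usage sketch (the config spelling below is an assumption; the commit only names the flag `use_direct_linearize`):

```python
import jax
import jax.numpy as jnp

# Assumption: the flag is surfaced through jax.config as
# "jax_use_direct_linearize"; the commit only calls it `use_direct_linearize`.
jax.config.update("jax_use_direct_linearize", True)

def f(x):
  return jnp.sin(x) * x

# jax.linearize returns the primal value and a linear tangent function.
# Under the direct linearize trace these are built without a separate
# JVP-then-partial-eval pass.
y, f_lin = jax.linearize(f, 2.0)
print(y, f_lin(1.0))
```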
2024-11-20 10:03:00 -08:00
Christos Perivolaropoulos
8d84f28373 [pallas mgpu] Lowering for while loops as long as they are secretly for loops.
PiperOrigin-RevId: 698427307
2024-11-20 10:00:14 -08:00
Skye Wanderman-Milne
6222592625 Fix KeyError recently introduced in cloud_tpu_init.py
This fixes a bug introduced in https://github.com/jax-ml/jax/pull/24889
2024-11-20 17:46:06 +00:00
jax authors
439d34da15 Merge pull request #25005 from jakevdp:py313
PiperOrigin-RevId: 698413430
2024-11-20 09:15:03 -08:00
jax authors
800add2a03 Merge pull request #25007 from jakevdp:deps
PiperOrigin-RevId: 698413340
2024-11-20 09:13:05 -08:00
Jake VanderPlas
85e2969aea Deprecate several private APIs in jax.lib 2024-11-20 08:48:26 -08:00
Chris Jones
1e9e85a39e Simplify handling of DotAlgorithmPreset output types.
Create a clear distinction between the type used for accumulation and possible output types.

PiperOrigin-RevId: 698399447
2024-11-20 08:26:44 -08:00
Jake VanderPlas
a4266b5e31 Mention python 3.13 in docs & package metadata 2024-11-20 08:23:19 -08:00
jax authors
a582df0297 Update XLA dependency to use revision
fcee07f619.

PiperOrigin-RevId: 698371906
2024-11-20 06:39:06 -08:00
Sergei Lebedev
1df4b5f798 [pallas] Do not skip vmap tests on GPU when x64 is enabled
PiperOrigin-RevId: 698351984
2024-11-20 05:08:23 -08:00
Sergei Lebedev
04e4c69f7f [mosaic_gpu] Handle older jaxlibs in the profiler module
`measure` now raises a `RuntimeError` if the available `jaxlib` does not have
the required custom calls.

PiperOrigin-RevId: 698351662
2024-11-20 05:06:24 -08:00
Sergei Lebedev
f442d40f92 [mosaic_gpu] Fixed FragmentedArray comparisons with literals
PiperOrigin-RevId: 698343858
2024-11-20 04:31:28 -08:00
Sergei Lebedev
c76e5fe9a0 [pallas:mosaic_gpu] copy_smem_to_gmem now supports wait_read_only
PiperOrigin-RevId: 698343812
2024-11-20 04:29:33 -08:00
Peter Buchlovsky
14da7ebb76 [pallas:mosaic_gpu] Add Pallas Mosaic GPU lowering for jax.lax.bitcast_convert_type.
Only handles the case where operand type and target type have the same bitwidth.
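
For context, a minimal same-bitwidth use of the op being lowered (shown in plain JAX rather than inside a Pallas kernel):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.float32(1.0)
# Reinterpret the 32-bit float pattern as a uint32. Operand and target
# types share a bitwidth, which is the only case this lowering handles.
bits = lax.bitcast_convert_type(x, jnp.uint32)
print(hex(int(bits)))  # 0x3f800000
```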

PiperOrigin-RevId: 698332564
2024-11-20 03:41:19 -08:00
Peter Buchlovsky
1afb05e2e2 [mosaic_gpu] Fix signedness handling in FragmentedArray._pointwise.
Only propagate signedness from operands when the output type of `op` is an `ir.IntegerType`.

PiperOrigin-RevId: 698324596
2024-11-20 03:01:48 -08:00
jax authors
ae46b7564e Merge pull request #24593 from froystig:random-dtypes
PiperOrigin-RevId: 698268678
2024-11-19 23:04:06 -08:00
jax authors
4d60db1741 Add test_compute_on_host_shared_sharding in memories_test
PiperOrigin-RevId: 698250352
2024-11-19 21:33:27 -08:00
Roy Frostig
4bb81075bc represent random.key_impl of builtin RNGs by canonical string name
We do not have a great reason to return specs here, and sticking to
strings instead can help with simple serialization.
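
A quick sketch of the resulting behavior (the exact string value shown is an assumption):

```python
import jax

key = jax.random.key(0)          # default RNG implementation
impl = jax.random.key_impl(key)
# After this change `impl` is a canonical string name, e.g. "threefry2x32",
# rather than a spec object, which makes it easy to serialize.
print(impl)
```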
2024-11-19 20:58:10 -08:00
Naums Mogers
6c291d67b7 [Mosaic] Add tpu.log verification on SC
Guards against using formatting and targeting vector subcores on SC.

PiperOrigin-RevId: 698222100
2024-11-19 19:04:29 -08:00
Peter Hawkins
867a36189b Fix a bug where constant deduplication used an inappropriate inequality.
We need to compare constants for bitwise equality, not, e.g., floating-point equality. The change that added deduplication conflated +0.0 and -0.0, which caused a downstream test not to terminate.
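
A small illustration of why floating-point equality is the wrong test here:

```python
import numpy as np

a, b = np.float32(0.0), np.float32(-0.0)
print(a == b)                      # True:  float equality conflates +0.0 and -0.0
print(a.tobytes() == b.tobytes())  # False: bitwise comparison keeps them distinct
```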

PiperOrigin-RevId: 698221147
2024-11-19 18:59:49 -08:00
Jake VanderPlas
8c71d1ad6d Make deprecated jax.experimental.array_api module visibility internal-only
This is in preparation for the module to be removed.

PiperOrigin-RevId: 698215225
2024-11-19 18:33:07 -08:00
Naums Mogers
c04aec9d52 [Mosaic] Extend tpu.sem_signal with subcore_id
This change:
- Bumps the Mosaic version to 4 in `serde.cc`.

- Adds an optional `subcore_id` parameter to `tpu.sem_signal` for signalling specific subcores.

- Extends deserialization to correctly parse older versions of Mosaic that lack the new `subcore_id` parameter of `tpu.sem_signal`.

PiperOrigin-RevId: 698163836
2024-11-19 15:22:59 -08:00
Peter Hawkins
525b646c0e Reverts 2075b091c4e83f0bdbd0d47812a72114fb8b937a
PiperOrigin-RevId: 698152759
2024-11-19 14:47:24 -08:00
Mason Chang
42fbd301fc Move JAX example to public XLA:CPU API
PiperOrigin-RevId: 698143471
2024-11-19 14:19:29 -08:00
jax authors
3161a28424 Update XLA dependency to use revision
229f376e04.

PiperOrigin-RevId: 698136955
2024-11-19 14:01:25 -08:00
Naums Mogers
0d36b0b433 [Mosaic] Add target core type parameter to tpu.sem_signal
Adds an optional target core type parameter to `tpu.sem_signal` for cross-core signalling.
If the target core type is not provided, it is assumed to be that of the core issuing the signal.
The issuing core type is determined from the core type annotation of the parent function; if the annotation is not provided, it is assumed to be TensorCore.

PiperOrigin-RevId: 698129842
2024-11-19 13:40:13 -08:00
Sergei Lebedev
1bf70fbbc4 [pallas:mosaic_gpu] copy_gmem_to_smem no longer requires barrier to be a keyword argument
... because there really isn't any reason to require that.

PiperOrigin-RevId: 698116984
2024-11-19 13:02:35 -08:00
jax authors
2075b091c4 Merge pull request #24970 from hawkinsp:split
PiperOrigin-RevId: 698112383
2024-11-19 12:48:55 -08:00
Peter Hawkins
2c80d1af50 Add a new API jax.lax.split.
This API does not add expressive power, since it is already possible to split arrays by repeated slicing. Its purpose is to be a primitive that is the transpose of `lax.concatenate`, so that primitives like `jnp.unstack` can be differentiated more efficiently.

Before:
```
In [1]: import jax.numpy as jnp, jax

In [2]: x = jnp.ones((3,))

In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
    f:f32[5,3] = pjit[
      name=unstack
      jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
          l:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] k
          m:f32[5,3] = pad[padding_config=((4, 0, 0), (0, 0, 0))] l 0.0
          n:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] j
          o:f32[5,3] = pad[padding_config=((3, 1, 0), (0, 0, 0))] n 0.0
          p:f32[5,3] = add_any m o
          q:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] i
          r:f32[5,3] = pad[padding_config=((2, 2, 0), (0, 0, 0))] q 0.0
          s:f32[5,3] = add_any p r
          t:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] h
          u:f32[5,3] = pad[padding_config=((1, 3, 0), (0, 0, 0))] t 0.0
          v:f32[5,3] = add_any s u
          w:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] g
          x:f32[5,3] = pad[padding_config=((0, 4, 0), (0, 0, 0))] w 0.0
          y:f32[5,3] = add_any v x
        in (y,) }
    ] a b c d e
  in (f,) }
```

Note in particular the `pad` calls, which are the transpose of `slice`. Transposing the slice-based split has the effect of forming many dense intermediate cotangents.

After:
```
In [1]: import jax.numpy as jnp, jax

In [2]: x = jnp.ones((3,))

In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
    f:f32[5,3] = pjit[
      name=unstack
      jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
          l:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] k
          m:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] j
          n:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] i
          o:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] h
          p:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] g
          q:f32[5,3] = concatenate[dimension=0] p o n m l
        in (q,) }
    ] a b c d e
  in (f,) }
```
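
A minimal usage sketch of the new API (the `sizes`/`axis` argument names are assumed from the description above):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(15.0).reshape(5, 3)
# Split the leading axis into pieces of sizes 2, 2, and 1. The transpose of
# this primitive is a single `lax.concatenate`, rather than a sum of pads.
parts = lax.split(x, sizes=(2, 2, 1), axis=0)
print([p.shape for p in parts])  # [(2, 3), (2, 3), (1, 3)]
```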
2024-11-19 15:25:47 -05:00
Dan Foreman-Mackey
a59bbb7cd7 Add test utility for accessing jaxlib version tuple.
We frequently need to condition tests on the current version of jaxlib. This change exposes the version tuple directly as part of `jtu` so that we don't need to import `jax._src.lib.version` in the tests.
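
A hedged sketch of how a test might use this; the exact `jtu` accessor name is an assumption, since the commit only says the version tuple is exposed via `jtu`:

```python
from jax._src import test_util as jtu  # the usual `jtu` alias used in JAX tests

# Hypothetical accessor name; the commit only states that the jaxlib version
# tuple is exposed via `jtu` instead of importing jax._src.lib.version.
if jtu.jaxlib_version() < (0, 4, 36):
  print("skipping: this test requires a newer jaxlib")
```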

PiperOrigin-RevId: 698097487
2024-11-19 12:00:32 -08:00
Justin Fu
c44f11d15e Add alternate implementation of threefry as a pallas kernel.
Current restrictions:
1) Dynamic grid sizes are not supported yet. Supporting them could in theory allow us to avoid recompiling the kernel for different shapes.
2) fold_in and split still use the original rules. There isn't a huge benefit to using the kernel for them right now, since the input is so small and we can't avoid re-compilation due to (1).
3) High bits on the counter are not yet supported, meaning we can generate at most 4 billion numbers in one call. This is a fringe use case, since we only support 32-bit, and generating 4 billion 32-bit numbers would consume 16 GB of HBM (an entire TPU v5p worth of HBM).

PiperOrigin-RevId: 698086352
2024-11-19 11:26:30 -08:00
Jevin Jiang
6c31efa3f3 [Mosaic TPU] Add general tpu.vector_store and support masked store.
This CL introduces a general store op, tpu.vector_stores, which aims to unify vector::store, tpu::strided_load, and vector::masked_store. tpu.vector_stores should also provide a general lowering interface for both TensorCore and SparseCore.

This CL also adds support for (dynamic) masked stores.

PiperOrigin-RevId: 698067741
2024-11-19 10:33:09 -08:00
Dan Foreman-Mackey
3556a83334 Add missing version guard in GPU tests for jnp.poly.
jaxlib v0.4.35 is required to run `jnp.linalg.eig` on GPU, which in turn is required for `poly`.

PiperOrigin-RevId: 698052642
2024-11-19 09:52:45 -08:00
jax authors
6929a97c0c Merge pull request #24968 from nireekshak:testingbranch
PiperOrigin-RevId: 698051658
2024-11-19 09:50:34 -08:00
jax authors
9d3eda17fd Merge pull request #24942 from jeertmans:patch-1
PiperOrigin-RevId: 698031586
2024-11-19 08:44:30 -08:00
Jérome Eertmans
d912034cb5 fix(docs): typos in macro name
chore(docs): sync .md file
2024-11-19 16:42:19 +01:00
nireekshak
1458d3dd56 Fix some typos 2024-11-19 15:04:55 +00:00
jax authors
da50ad7ee3 [AutoPGLE] Use compile options to override debug options instead of XLA_FLAGS.
PiperOrigin-RevId: 697924164
2024-11-19 01:47:54 -08:00
jax authors
d397dd9684 Implement lax.pad in Pallas.
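For context, the semantics of the op now lowered by Pallas, shown in plain JAX (`padding_config` entries are (low, high, interior) per dimension):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 3.0])
# One element of low padding, two of high padding, no interior padding.
y = lax.pad(x, 0.0, [(1, 2, 0)])
print(y)  # [0. 1. 2. 3. 0. 0.]
```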
PiperOrigin-RevId: 697897093
2024-11-18 23:59:20 -08:00
jax authors
12a43f1fff Merge pull request #24853 from yliu120:check_proxy_envs
PiperOrigin-RevId: 697831284
2024-11-18 18:42:49 -08:00
Jevin Jiang
0fe77bc9f0 [Mosaic TPU] Support relayout for mask vector
We cast the i1 vector (mask) to an i32 vector before relayout, then cast back to an i1 vector (mask) once the relayout is finished.

PiperOrigin-RevId: 697823543
2024-11-18 18:07:15 -08:00
jax authors
58103e5aee Merge pull request #24861 from yliu120:add_versions
PiperOrigin-RevId: 697822709
2024-11-18 18:04:38 -08:00
jax authors
4a9346e4b8 Merge pull request #24945 from hawkinsp:gamma
PiperOrigin-RevId: 697819608
2024-11-18 17:51:29 -08:00
Peter Hawkins
c5e8ae80f9 Update jax.scipy.special.gamma and gammasgn to return NaN for negative integer inputs.
This matches upstream SciPy: https://github.com/scipy/scipy/pull/21827.

Fixes #24875
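
A quick check of the behavior described above:

```python
import jax.numpy as jnp
from jax.scipy.special import gamma, gammasgn

x = jnp.array([-2.0, -1.0, 0.5])
# Negative integers are poles of the gamma function; after this change both
# functions return NaN there, matching upstream SciPy.
print(gamma(x))     # [nan nan 1.7724539]
print(gammasgn(x))  # [nan nan 1.]
```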
2024-11-18 20:33:27 -05:00
Chris Jones
45c9c0a585 [pallas] Minor simplifications to Pallas interpreter.
BlockMappings are always present now.

PiperOrigin-RevId: 697807120
2024-11-18 17:10:10 -08:00