24244 Commits

Author SHA1 Message Date
Skye Wanderman-Milne
6222592625 Fix KeyError recently introduced in cloud_tpu_init.py
This fixes a bug introduced in https://github.com/jax-ml/jax/pull/24889
2024-11-20 17:46:06 +00:00
jax authors
439d34da15 Merge pull request #25005 from jakevdp:py313
PiperOrigin-RevId: 698413430
2024-11-20 09:15:03 -08:00
jax authors
800add2a03 Merge pull request #25007 from jakevdp:deps
PiperOrigin-RevId: 698413340
2024-11-20 09:13:05 -08:00
Jake VanderPlas
85e2969aea Deprecate several private APIs in jax.lib 2024-11-20 08:48:26 -08:00
Chris Jones
1e9e85a39e Simplify handling of DotAlgorithmPreset output types.
Create a clear distinction between the type used for accumulation and possible output types.

PiperOrigin-RevId: 698399447
2024-11-20 08:26:44 -08:00
Jake VanderPlas
a4266b5e31 Mention python 3.13 in docs & package metadata 2024-11-20 08:23:19 -08:00
jax authors
a582df0297 Update XLA dependency to use revision
fcee07f619.

PiperOrigin-RevId: 698371906
2024-11-20 06:39:06 -08:00
Sergei Lebedev
1df4b5f798 [pallas] Do not skip vmap tests on GPU when x64 is enabled
PiperOrigin-RevId: 698351984
2024-11-20 05:08:23 -08:00
Sergei Lebedev
04e4c69f7f [mosaic_gpu] Handle older jaxlibs in the profiler module
`measure` now raises a `RuntimeError` if the available `jaxlib` does not have
the required custom calls.

PiperOrigin-RevId: 698351662
2024-11-20 05:06:24 -08:00
Sergei Lebedev
f442d40f92 [mosaic_gpu] Fixed FragmentedArray comparisons with literals
PiperOrigin-RevId: 698343858
2024-11-20 04:31:28 -08:00
Sergei Lebedev
c76e5fe9a0 [pallas:mosaic_gpu] copy_smem_to_gmem now supports wait_read_only
PiperOrigin-RevId: 698343812
2024-11-20 04:29:33 -08:00
Peter Buchlovsky
14da7ebb76 [pallas:mosaic_gpu] Add Pallas Mosaic GPU lowering for jax.lax.bitcast_convert_type.
Only handles the case where operand type and target type have the same bitwidth.
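
As a rough illustration of what a same-bitwidth bitcast means, the following plain-Python sketch reinterprets the 32 bits of a float32 as an unsigned 32-bit integer and back. It uses the standard `struct` module rather than Pallas or Mosaic GPU, and the helper names are hypothetical:

```python
import struct


def bitcast_f32_to_u32(x: float) -> int:
    """Reinterpret the 32 bits of a float32 as an unsigned 32-bit integer."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bitcast_u32_to_f32(bits: int) -> float:
    """Reinterpret an unsigned 32-bit integer as a float32."""
    return struct.unpack("<f", struct.pack("<I", bits))[0]


one_bits = bitcast_f32_to_u32(1.0)
print(hex(one_bits))  # 0x3f800000

# The round trip preserves the value because no bits are added or dropped.
round_trip = bitcast_u32_to_f32(bitcast_f32_to_u32(-2.5))
print(round_trip)  # -2.5
```

Both types occupy exactly 32 bits, so the conversion is a pure reinterpretation with no change in element count.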

PiperOrigin-RevId: 698332564
2024-11-20 03:41:19 -08:00
Peter Buchlovsky
1afb05e2e2 [mosaic_gpu] Fix signedness handling in FragmentedArray._pointwise.
Only propagate signedness from operands when the output type of `op` is an `ir.IntegerType`.

PiperOrigin-RevId: 698324596
2024-11-20 03:01:48 -08:00
jax authors
ae46b7564e Merge pull request #24593 from froystig:random-dtypes
PiperOrigin-RevId: 698268678
2024-11-19 23:04:06 -08:00
jax authors
4d60db1741 Add test_compute_on_host_shared_sharding in memories_test
PiperOrigin-RevId: 698250352
2024-11-19 21:33:27 -08:00
Roy Frostig
4bb81075bc represent random.key_impl of builtin RNGs by canonical string name
We do not have a strong reason to return specs here, and sticking to
strings instead can help with simple serialization.
2024-11-19 20:58:10 -08:00
Naums Mogers
6c291d67b7 [Mosaic] Add tpu.log verification on SC
Guards against using formatting and targeting vector subcores on SC.

PiperOrigin-RevId: 698222100
2024-11-19 19:04:29 -08:00
Peter Hawkins
867a36189b Fix a bug where constant deduplication used an inappropriate equality comparison.
We need to compare constants for bitwise equality, not, e.g., floating point equality. The change that added deduplication caused us to conflate +0.0 and -0.0, which led a downstream test not to terminate.
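
A minimal sketch of the distinction, in plain Python rather than the actual JAX lowering code (the `dedup` and `float64_bits` helpers are hypothetical): deduplicating by floating point equality merges +0.0 and -0.0, while deduplicating by bit pattern keeps them apart.

```python
import math
import struct


def float64_bits(x: float) -> int:
    """Return the raw 64-bit pattern of a Python float."""
    return struct.unpack("<Q", struct.pack("<d", x))[0]


def dedup(constants, key):
    """Keep the first constant seen for each distinct key, preserving order."""
    seen = {}
    for c in constants:
        seen.setdefault(key(c), c)
    return list(seen.values())


consts = [0.0, -0.0, 1.5, 1.5]

# Floating point equality: 0.0 == -0.0, so the two zeros are conflated.
by_value = dedup(consts, key=lambda c: c)

# Bitwise equality: the sign bit differs, so both zeros survive.
by_bits = dedup(consts, key=float64_bits)

print(len(by_value), len(by_bits))
print([math.copysign(1.0, c) for c in by_bits])
```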

PiperOrigin-RevId: 698221147
2024-11-19 18:59:49 -08:00
Jake VanderPlas
8c71d1ad6d Make deprecated jax.experimental.array_api module visibility internal-only
This is in preparation for the module to be removed.

PiperOrigin-RevId: 698215225
2024-11-19 18:33:07 -08:00
Naums Mogers
c04aec9d52 [Mosaic] Extend tpu.sem_signal with subcore_id
This change:
- Bumps up the version of Mosaic to 4 in `serde.cc`.

- Adds optional `subcore_id` parameter to `tpu.sem_signal` for signalling specific subcores.

- Extends deserialization to correctly parse the older versions of Mosaic without the new parameter `subcore_id` of `tpu.sem_signal`.

PiperOrigin-RevId: 698163836
2024-11-19 15:22:59 -08:00
Peter Hawkins
525b646c0e Reverts 2075b091c4e83f0bdbd0d47812a72114fb8b937a
PiperOrigin-RevId: 698152759
2024-11-19 14:47:24 -08:00
Mason Chang
42fbd301fc Move JAX example to public XLA:CPU API
PiperOrigin-RevId: 698143471
2024-11-19 14:19:29 -08:00
jax authors
3161a28424 Update XLA dependency to use revision
229f376e04.

PiperOrigin-RevId: 698136955
2024-11-19 14:01:25 -08:00
Naums Mogers
0d36b0b433 [Mosaic] Add target core type parameter to tpu.sem_signal
Adds the optional core type parameter to `tpu.sem_signal` for cross-core signalling.
If the target core type is not provided, the target core type is assumed to be that of the core issuing the signal.
The issuing core type is determined based on the core type annotation of the parent function; if the annotation is not provided, the issuing core type is assumed to be TensorCore.
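
The fallback rules above can be sketched in plain Python. This is not the Mosaic implementation; the `CoreType` enum, its member names, and both function names are assumptions made for illustration:

```python
from enum import Enum
from typing import Optional


class CoreType(Enum):
    # Hypothetical names; the real Mosaic core type enum may differ.
    TENSOR_CORE = "tensor_core"
    SPARSE_CORE = "sparse_core"


def issuing_core_type(function_annotation: Optional[CoreType]) -> CoreType:
    """Core type of the issuing core: the parent function's annotation, else TensorCore."""
    if function_annotation is not None:
        return function_annotation
    return CoreType.TENSOR_CORE


def resolve_target_core_type(
    target: Optional[CoreType], function_annotation: Optional[CoreType]
) -> CoreType:
    """Target core type: the explicit parameter, else the issuing core's type."""
    if target is not None:
        return target
    return issuing_core_type(function_annotation)


# No target and no annotation: falls all the way back to TensorCore.
default_target = resolve_target_core_type(None, None)
# No target, annotated parent function: the signal stays on the issuing core type.
same_core_target = resolve_target_core_type(None, CoreType.SPARSE_CORE)
# Explicit target: cross-core signalling.
cross_core_target = resolve_target_core_type(CoreType.TENSOR_CORE, CoreType.SPARSE_CORE)
```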

PiperOrigin-RevId: 698129842
2024-11-19 13:40:13 -08:00
Sergei Lebedev
1bf70fbbc4 [pallas:mosaic_gpu] copy_gmem_to_smem no longer requires barrier to be a keyword argument
... because there really isn't any reason to require that.

PiperOrigin-RevId: 698116984
2024-11-19 13:02:35 -08:00
jax authors
2075b091c4 Merge pull request #24970 from hawkinsp:split
PiperOrigin-RevId: 698112383
2024-11-19 12:48:55 -08:00
Peter Hawkins
2c80d1af50 Add a new API jax.lax.split.
This API does not add expressive power, since it is already possible to split arrays by repeated slicing. Its purpose is to be a primitive that is the transpose of `lax.concatenate`, so that primitives like `jnp.unstack` can be differentiated more efficiently.

Before:
```
In [1]: import jax.numpy as jnp, jax

In [2]: x = jnp.ones((3,))

In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
    f:f32[5,3] = pjit[
      name=unstack
      jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
          l:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] k
          m:f32[5,3] = pad[padding_config=((4, 0, 0), (0, 0, 0))] l 0.0
          n:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] j
          o:f32[5,3] = pad[padding_config=((3, 1, 0), (0, 0, 0))] n 0.0
          p:f32[5,3] = add_any m o
          q:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] i
          r:f32[5,3] = pad[padding_config=((2, 2, 0), (0, 0, 0))] q 0.0
          s:f32[5,3] = add_any p r
          t:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] h
          u:f32[5,3] = pad[padding_config=((1, 3, 0), (0, 0, 0))] t 0.0
          v:f32[5,3] = add_any s u
          w:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] g
          x:f32[5,3] = pad[padding_config=((0, 4, 0), (0, 0, 0))] w 0.0
          y:f32[5,3] = add_any v x
        in (y,) }
    ] a b c d e
  in (f,) }
```

Note in particular the `pad` calls, which are the transpose of `slice`. Transposing the split has the effect of forming many dense intermediate cotangents.

After:
```
In [1]: import jax.numpy as jnp, jax

In [2]: x = jnp.ones((3,))

In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
    f:f32[5,3] = pjit[
      name=unstack
      jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
          l:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] k
          m:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] j
          n:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] i
          o:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] h
          p:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] g
          q:f32[5,3] = concatenate[dimension=0] p o n m l
        in (q,) }
    ] a b c d e
  in (f,) }
```
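
The transpose relationship between splitting and concatenating can also be checked numerically. The sketch below uses plain Python lists instead of JAX arrays, and its `concatenate`, `split`, and `dot` helpers are hypothetical stand-ins rather than the `jax.lax` functions. It verifies the adjoint identity `<concatenate(xs), y> == <xs, split(y)>`:

```python
def concatenate(chunks):
    """Join a list of lists into one flat list."""
    return [v for chunk in chunks for v in chunk]


def split(values, sizes):
    """Cut a flat list into consecutive chunks with the given sizes."""
    assert sum(sizes) == len(values)
    out, start = [], 0
    for size in sizes:
        out.append(values[start:start + size])
        start += size
    return out


def dot(a, b):
    """Inner product of two equal-length lists."""
    return sum(x * y for x, y in zip(a, b))


xs = [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]]
y = [0.5, -1.0, 2.0, 3.0, 0.0, 1.0]
sizes = [len(chunk) for chunk in xs]

# Pairing concatenate(xs) with y ...
lhs = dot(concatenate(xs), y)
# ... gives the same number as pairing each chunk of xs with the matching piece of split(y).
rhs = sum(dot(chunk, part) for chunk, part in zip(xs, split(y, sizes)))

print(lhs, rhs)  # 22.5 22.5
```

Because the two operations are transposes of each other, the cotangent of a split can be formed with a single concatenation instead of a sum of padded intermediates.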
2024-11-19 15:25:47 -05:00
Dan Foreman-Mackey
a59bbb7cd7 Add test utility for accessing jaxlib version tuple.
We frequently need to condition tests on the current version of jaxlib. This change exposes the version tuple directly as part of `jtu` so that we don't need to import `jax._src.lib.version` in the tests.
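
A small sketch of why a version tuple is convenient for such guards. The helper names below are hypothetical and are not the `jtu` API; the point is only that tuples compare numerically, component by component, while version strings compare lexicographically:

```python
def parse_version(version: str) -> tuple:
    """Turn a dotted version string such as '0.4.35' into a tuple of ints."""
    return tuple(int(part) for part in version.split("."))


def meets_minimum(installed: tuple, minimum: tuple) -> bool:
    """True if the installed version tuple is at least the minimum."""
    return installed >= minimum


installed = parse_version("0.4.9")
minimum = parse_version("0.4.35")

# Tuple comparison is numeric: 9 < 35, so the guard correctly fails.
tuple_ok = meets_minimum(installed, minimum)

# String comparison is lexicographic: "9" > "3", which gives the wrong answer.
string_ok = "0.4.9" >= "0.4.35"

print(tuple_ok, string_ok)  # False True
```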

PiperOrigin-RevId: 698097487
2024-11-19 12:00:32 -08:00
Justin Fu
c44f11d15e Add alternate implementation of threefry as a pallas kernel.
Current restrictions:
1) Dynamic grid sizes are not supported yet. This could in theory allow us to not recompile the kernel for different shapes.
2) fold_in and split still use the original rules. But there isn't a huge benefit to using the kernel right now since the input is so small and we can't avoid re-compilation due to (1).
3) High bits on the counter are not supported yet, meaning we can generate at most 4B numbers in one call. This is a fringe use case since we only support 32-bit, and generating 4B 32-bit numbers would consume 16GB of HBM (an entire TPU v5p worth of HBM).
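
The arithmetic behind the figures in (3) can be spelled out as follows, assuming a 32-bit counter and one 32-bit output value per counter value (an assumption made for this sketch, not a statement about the kernel internals):

```python
COUNTER_BITS = 32      # only the low 32 bits of the counter are used
BYTES_PER_VALUE = 4    # each generated number is a 32-bit value

max_values = 2 ** COUNTER_BITS
total_bytes = max_values * BYTES_PER_VALUE
total_gib = total_bytes / 2 ** 30

print(max_values)  # 4294967296, i.e. about 4B numbers
print(total_gib)   # 16.0
```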

PiperOrigin-RevId: 698086352
2024-11-19 11:26:30 -08:00
Jevin Jiang
6c31efa3f3 [Mosaic TPU] Add general tpu.vector_store and support masked store.
This CL introduces a general store op called tpu.vector_store, which aims to unify vector::store, tpu::strided_load, and vector::masked_store. tpu.vector_store should also provide a general interface for lowering on both TensorCore and SparseCore.

This CL also adds support for (dynamic) masked stores.

PiperOrigin-RevId: 698067741
2024-11-19 10:33:09 -08:00
Dan Foreman-Mackey
3556a83334 Add missing version guard in GPU tests for jnp.poly.
jaxlib v0.4.35 is required for running `jnp.linalg.eig` on GPU, which in turn is required for `poly`.

PiperOrigin-RevId: 698052642
2024-11-19 09:52:45 -08:00
jax authors
6929a97c0c Merge pull request #24968 from nireekshak:testingbranch
PiperOrigin-RevId: 698051658
2024-11-19 09:50:34 -08:00
jax authors
9d3eda17fd Merge pull request #24942 from jeertmans:patch-1
PiperOrigin-RevId: 698031586
2024-11-19 08:44:30 -08:00
Jérome Eertmans
d912034cb5
fix(docs): typos in macro name
chore(docs): sync .md file
2024-11-19 16:42:19 +01:00
nireekshak
1458d3dd56 Fix some typos 2024-11-19 15:04:55 +00:00
jax authors
da50ad7ee3 [AutoPGLE] Use compile options to override debug options instead of XLA_FLAGS.
PiperOrigin-RevId: 697924164
2024-11-19 01:47:54 -08:00
jax authors
d397dd9684 Implement lax.pad in Pallas.
PiperOrigin-RevId: 697897093
2024-11-18 23:59:20 -08:00
jax authors
12a43f1fff Merge pull request #24853 from yliu120:check_proxy_envs
PiperOrigin-RevId: 697831284
2024-11-18 18:42:49 -08:00
Jevin Jiang
0fe77bc9f0 [Mosaic TPU] Support relayout for mask vector
We cast the i1 vector (mask) to an i32 vector before relayout and then cast it back to an i1 vector (mask) after relayout is finished.

PiperOrigin-RevId: 697823543
2024-11-18 18:07:15 -08:00
jax authors
58103e5aee Merge pull request #24861 from yliu120:add_versions
PiperOrigin-RevId: 697822709
2024-11-18 18:04:38 -08:00
jax authors
4a9346e4b8 Merge pull request #24945 from hawkinsp:gamma
PiperOrigin-RevId: 697819608
2024-11-18 17:51:29 -08:00
Peter Hawkins
c5e8ae80f9 Update jax.scipy.special.gamma and gammasgn to return NaN for negative integer inputs.
Change to match upstream scipy: https://github.com/scipy/scipy/pull/21827.

Fixes #24875
2024-11-18 20:33:27 -05:00
Chris Jones
45c9c0a585 [pallas] Minor simplifications to Pallas interpreter.
BlockMappings are always present now.

PiperOrigin-RevId: 697807120
2024-11-18 17:10:10 -08:00
Yash Katariya
2c68569af0 Fix a bug where mesh checking was not correct
PiperOrigin-RevId: 697792885
2024-11-18 16:22:27 -08:00
Yash Katariya
e904c177f7 Delete _normalized_spec from NamedSharding
PiperOrigin-RevId: 697779844
2024-11-18 15:35:38 -08:00
jax authors
6952ddf4c6 Merge pull request #24958 from barnesjoseph:add-font-fallback
PiperOrigin-RevId: 697774925
2024-11-18 15:18:34 -08:00
barnesjoseph
d4316b5760 Adds font fallbacks 2024-11-18 14:46:10 -08:00
jax authors
91891cb600 Merge pull request #23585 from apivovarov:float8_e4m3
PiperOrigin-RevId: 697760985
2024-11-18 14:34:59 -08:00
jax authors
b3ca6c47cc Update XLA dependency to use revision
082a701470.

PiperOrigin-RevId: 697756717
2024-11-18 14:21:51 -08:00
jax authors
16ed283f5e Merge pull request #24957 from hawkinsp:arm
PiperOrigin-RevId: 697755310
2024-11-18 14:17:47 -08:00