15048 Commits

Author SHA1 Message Date
Jake VanderPlas
83383fc717 Error on numpy array conversion of PRNG key array 2024-11-07 10:08:49 -08:00
Adam Paszke
de06584d98 [Mosaic GPU] Introduce a more flexible layout system
Our layouts so far have been tailored to the limited set of use cases
we've tried, and they're still not general enough to handle all of the
register layouts needed for WGMMA or mixed precision matmuls (incl.
intermediate steps during conversions). Instead of adding more special
cases, I decided to adopt XLA tiled layouts, and they do seem to work
quite well!

This change only lays the groundwork for the new layout system. Future
changes will build upon it to add new features and eventually replace
`WGMMA_LAYOUT` altogether.

PiperOrigin-RevId: 694105514
2024-11-07 07:08:51 -08:00
Adam Paszke
f8dba3c8a4 [Pallas:MGPU] Add support for multiple heads in attention
PiperOrigin-RevId: 694104006
2024-11-07 07:03:35 -08:00
Adam Paszke
506671291a [Mosaic GPU] Fix the ordering of transforms in async_copy
Previously we did not fully discharge the squeezing of indexed dims
before applying other GMEM transforms, which could lead to failures
because those transforms did not anticipate the increased rank.

PiperOrigin-RevId: 694098739
2024-11-07 06:41:42 -08:00
jax authors
4cc80889b6 Merge pull request #24645 from andportnoy:mosaic-gpu-improve-benchmarking-scripts
PiperOrigin-RevId: 694095765
2024-11-07 06:28:12 -08:00
jax authors
37af1002c7 Merge pull request #24602 from rdyro:rdyro-decode-attention-mask
PiperOrigin-RevId: 693835080
2024-11-06 13:05:29 -08:00
Jake VanderPlas
4e45a6d94d Remove some obsolete deprecation registrations
PiperOrigin-RevId: 693793727
2024-11-06 11:10:04 -08:00
jax authors
36017a4abb Merge pull request #24733 from jakevdp:fix-quantile
PiperOrigin-RevId: 693793427
2024-11-06 11:08:30 -08:00
jax authors
cbaafbbe99 Merge pull request #24723 from jakevdp:beta-dep
PiperOrigin-RevId: 693759757
2024-11-06 09:42:29 -08:00
jax authors
542cb2e57e Fix a bug in jax.scipy.stats.rankdata leading to breakage with shape polymorphism.
PiperOrigin-RevId: 693755546
2024-11-06 09:31:43 -08:00
Jake VanderPlas
d698da610a scipy.special.beta: remove deprecated x and y parameters 2024-11-06 09:01:27 -08:00
jax authors
0fcb2f3997 Merge pull request #24722 from jakevdp:tracer-hash
PiperOrigin-RevId: 693739151
2024-11-06 08:35:33 -08:00
Yash Katariya
b8c263a56c Add support for tpu v5e to jax.make_mesh
PiperOrigin-RevId: 693732928
2024-11-06 08:13:46 -08:00
Robert Dyro
d62510bfae Adding start index and kv_seq_len to decode kernel 2024-11-05 15:52:21 -08:00
Jake VanderPlas
44c6883cee Fix debug_nans false positive in jnp.quantile 2024-11-05 15:36:14 -08:00
jax authors
939b41f5bf Merge pull request #24717 from jakevdp:fix-rankdata
PiperOrigin-RevId: 693416741
2024-11-05 11:17:57 -08:00
jax authors
c1af808c8c Merge pull request #24710 from rajasekharporeddy:typos
PiperOrigin-RevId: 693412112
2024-11-05 11:06:20 -08:00
jax authors
497a5a35b4 Merge pull request #23468 from rdyro:rdyro-add-logging-env
PiperOrigin-RevId: 693395294
2024-11-05 10:23:24 -08:00
Robert Dyro
04f2ef9e93 Adding JAX_LOGGING_LEVEL configuration option 2024-11-05 09:56:46 -08:00
Jake VanderPlas
095bb0e742 Make Tracers non-hashable 2024-11-05 09:08:33 -08:00
Georg Stefan Schmid
7bdb2bf998 [jax.distributed] Enable grpc channel compression 2024-11-05 16:47:29 +00:00
Peter Hawkins
0e8acff5c6 Reverts a913fbf2fddc5b8c1b6c85b159d0eeb1bf65d461
PiperOrigin-RevId: 693360032
2024-11-05 08:32:25 -08:00
Dougal Maclaurin
478b750c29 Reverts f281c6f46475270a57a02416469226315377592c
PiperOrigin-RevId: 693339094
2024-11-05 07:17:14 -08:00
Jake VanderPlas
5f90f63d19 Improve efficiency of jax.scipy.stats.rankdata 2024-11-05 05:13:57 -08:00
Benjamin Chetioui
63e59c5fd7 [Mosaic GPU] Ensure that the dialect module can be loaded successfully.
This requires that the file providing the bindings has the same name as the
dialect it defines, since dialect search looks for a module path of the form
`<prefix>.<dialect namespace>`.
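
The lookup rule described above can be sketched in plain Python. This is an illustration only, not the MLIR bindings' actual code: `find_dialect_module` is a hypothetical helper, and the standard-library packages `xml` and `json` stand in for search prefixes so that the sketch runs anywhere.

```python
import importlib
from types import ModuleType
from typing import Optional, Sequence


def find_dialect_module(
    prefixes: Sequence[str], dialect_namespace: str
) -> Optional[ModuleType]:
    """Return the first importable module named `<prefix>.<dialect namespace>`."""
    for prefix in prefixes:
        candidate = f"{prefix}.{dialect_namespace}"
        try:
            return importlib.import_module(candidate)
        except ModuleNotFoundError:
            # No module with that exact name under this prefix; try the next one.
            continue
    return None


# Stand-in prefixes: `xml.decoder` does not exist, `json.decoder` does.
module = find_dialect_module(["xml", "json"], "decoder")
print(module.__name__ if module else "not found")
```

Under this rule, a bindings file whose name differs from the dialect namespace never matches `candidate`, which is the failure mode this commit guards against.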

PiperOrigin-RevId: 693241875
2024-11-05 00:47:21 -08:00
rajasekharporeddy
a80d027dd7 Fix Typos 2024-11-05 12:29:20 +05:30
jax authors
a913fbf2fd rollback due to data race
Reverts ab47d4687f647de3aa145a9a782fb7b4aaf92af4

PiperOrigin-RevId: 693191298
2024-11-04 21:05:33 -08:00
Peter Hawkins
ab47d4687f [JAX] [XLA:Python] Move JAX configuration objects into C++.
A noticeable amount of time during JAX tracing is spent getting and setting the value of config.State objects, in particular the thread-local values within that state. If we move that logic into C++, we can speed up that code.

There are two main ways we can get a speedup:
* Python thread-local state is based around a dictionary and isn't terribly fast.
* We can have the C++ jit dispatch path directly access the configuration items it needs to include in its cache key. We spend a considerable amount of time in effect eagerly computing cache keys via update_thread_local_jit_state, although most of that is pointless work. Instead, we can have `jit` simply pull the config items it needs on demand.

PiperOrigin-RevId: 693114411
2024-11-04 15:39:06 -08:00
Jake VanderPlas
e9acaa8484 Remove the initial argument to jax.nn.softmax and jax.nn.log_softmax.
This argument was deprecated in JAX v0.4.27 and has no effect in JAX v0.4.27 and later.
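
For callers that still pass `initial`, a minimal sketch of the remaining call form is shown below. It assumes a JAX version in which `jax.nn.softmax` accepts the `where` argument; the input values are made up for the example.

```python
import jax.numpy as jnp
from jax import nn

logits = jnp.array([1.0, 2.0, 3.0])
mask = jnp.array([True, True, False])

# `where` selects which entries take part in the softmax; no `initial` is passed.
probs = nn.softmax(logits, where=mask)

# The entries selected by the mask sum to 1.
print(float(probs[:2].sum()))
```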

PiperOrigin-RevId: 693023366
2024-11-04 10:55:21 -08:00
jax authors
26c0c5c764 Merge pull request #24692 from jiaxi98:cholesky_document
PiperOrigin-RevId: 693010148
2024-11-04 10:17:27 -08:00
jax authors
15febbf02d Merge pull request #24684 from hartikainen:fix-cuda_path
PiperOrigin-RevId: 693000918
2024-11-04 09:51:29 -08:00
Sergei Lebedev
d2bbd56405 [pallas:mosaic_gpu] lax.fori_loop lowering now promotes the carry to mgpu.FragmentedArrays
PiperOrigin-RevId: 692976037
2024-11-04 08:29:00 -08:00
Bart Chrzaszcz
3544efcade #sdy Fix Shardy bug where we weren't setting shmap in/out shardings as open.
If I revert the change in `shard_map.py`, the newly added unit test `test_partial_auto_propagate_through` fails with:
```
self.assertEqual(actual.sharding, sharding)
AssertionError: Named[18 chars]('i': 2, 'j': 2), spec=PartitionSpec(), memory_kind=device) != Named[18 chars]('i': 2, 'j': 2), spec=PartitionSpec('i',), memory_kind=device)
```
PiperOrigin-RevId: 692971413
2024-11-04 08:12:35 -08:00
Kristian Hartikainen
9df719f83f Fix _cuda_path for case when cuda_nvcc is a namespace package
When installed, e.g. via `pip` in a `venv`, `cuda_nvcc` comes out as a
namespace package. The previous logic found the `cuda_nvcc` import but
failed because `cuda_nvcc.__file__ is None`.
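
The behaviour described above can be reproduced with the standard library alone. The sketch below builds a throwaway namespace package called `fake_cuda_nvcc` (a made-up name, not the real package) and shows that `__file__` is `None` while `__path__` still points at the install directory. Whether the fix uses `__path__` in exactly this way is not stated in the message.

```python
import importlib
import os
import sys
import tempfile

with tempfile.TemporaryDirectory() as root:
    # A directory without __init__.py is importable as a namespace package.
    pkg_dir = os.path.join(root, "fake_cuda_nvcc")
    os.makedirs(os.path.join(pkg_dir, "bin"))
    sys.path.insert(0, root)
    try:
        importlib.invalidate_caches()
        pkg = importlib.import_module("fake_cuda_nvcc")
        file_attr = getattr(pkg, "__file__", None)
        search_paths = list(pkg.__path__)
    finally:
        # Leave the interpreter state as we found it.
        sys.path.remove(root)
        sys.modules.pop("fake_cuda_nvcc", None)

print(file_attr)     # None for a namespace package
print(search_paths)  # the directory the package was found in
```

Code that derives a location from `os.path.dirname(pkg.__file__)` therefore fails for such packages, whereas iterating over `pkg.__path__` does not.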
2024-11-04 18:06:55 +02:00
jiaxi98
95146deb6b issue #24691 2024-11-04 23:52:54 +08:00
Dougal Maclaurin
f281c6f464 Reverts ec39b592f7c096b0b8183723feaab2ed0d001041
PiperOrigin-RevId: 692949053
2024-11-04 06:54:06 -08:00
Dougal Maclaurin
ec39b592f7 Remove lattice system from JAX, especially raise_to_shaped (except as a no-op for backwards compat)
PiperOrigin-RevId: 692557993
2024-11-02 17:03:50 -07:00
Christos Perivolaropoulos
72eb5088b7 [jax] Mesh discharge rule should return None for inputs it did not touch.
PiperOrigin-RevId: 692519730
2024-11-02 12:13:14 -07:00
George Necula
292a00b35a [export] Cleanup in the export module.
With jax.experimental.export gone we can now do some cleanup in the export module.

In particular we remove the `export.args_spec` API, and the `lowering_platforms` arg for `export.export`. These were deprecated in June 2024.

PiperOrigin-RevId: 692398132
2024-11-01 22:56:44 -07:00
Matthew Johnson
0f3ba4250d support exec_time_optimization_effort and memory_fitting_effort xla compilation options

PiperOrigin-RevId: 692322944
2024-11-01 16:25:50 -07:00
Yash Katariya
fff33f90b2 Add compiler_options argument to jax.jit.
The same option is already available through the AOT path, i.e. `jit(f).lower(*args).compile(compiler_options={})`.
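
A minimal sketch of the two call forms follows. The function `f` and the option `xla_embed_ir_in_executable` are chosen for illustration and are assumptions, not taken from the commit; the direct `jit` form is left as a comment because it requires a JAX version that includes this change.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(x) * 2.0


x = jnp.arange(4.0)

# AOT path quoted in the commit message: options are passed to `compile`.
compiled = jax.jit(f).lower(x).compile(
    compiler_options={"xla_embed_ir_in_executable": True}
)
result = compiled(x)

# With this commit, the same dictionary can be given to `jit` directly:
#   jax.jit(f, compiler_options={"xla_embed_ir_in_executable": True})
```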

PiperOrigin-RevId: 692283964
2024-11-01 14:01:19 -07:00
Yash Katariya
07858fa98d [sharding_in_types] Allow device_put to reshard inputs. device_put is a good choice for resharding since it already handles transpose correctly because it tracks the src sharding too.
PiperOrigin-RevId: 692274137
2024-11-01 13:25:08 -07:00
Jake VanderPlas
97e8a4c8c6 Fix signatures test: new axis argument in trim_zeros 2024-11-01 10:15:31 -07:00
Naums Mogers
f462d7e586 [Mosaic] Set TPU CustomCall device type based on the core_type attribute
This CL deprecates the device_type parameter of `tpu_custom_call.as_tpu_kernel()` in favour of the `tpu.core_type` annotation.
The latter is more fine-grained: it is applied on `func.FuncOp` instead of the entire module, and supports `tc`, `sc_scalar_subcore` and `sc_vector_subcore`.

`device_type` of the TPU CustomCall HLO is set to `sparsecore` if a `sc_scalar_subcore` or `sc_vector_subcore` annotation is provided. Otherwise, `device_type` is not set and the CustomCall targets TC.
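
The selection rule above amounts to a small mapping. The sketch below restates it in Python; `device_type_for` is a hypothetical helper written for this illustration and is not part of Mosaic, and `None` is used here to stand for "`device_type` is not set".

```python
from typing import Optional

# Core types named in the commit message.
SPARSECORE_CORE_TYPES = frozenset({"sc_scalar_subcore", "sc_vector_subcore"})
KNOWN_CORE_TYPES = SPARSECORE_CORE_TYPES | {"tc"}


def device_type_for(core_type: Optional[str]) -> Optional[str]:
    """Map a `tpu.core_type` annotation to the CustomCall `device_type`.

    Returns None when `device_type` should be left unset, in which case
    the CustomCall targets TC.
    """
    if core_type is not None and core_type not in KNOWN_CORE_TYPES:
        raise ValueError(f"unknown core_type: {core_type!r}")
    if core_type in SPARSECORE_CORE_TYPES:
        return "sparsecore"
    return None


for annotation in (None, "tc", "sc_scalar_subcore", "sc_vector_subcore"):
    print(annotation, "->", device_type_for(annotation))
```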

PiperOrigin-RevId: 692212644
2024-11-01 10:02:49 -07:00
jax authors
2a41c04fef Merge pull request #24652 from jakevdp:old-deps
PiperOrigin-RevId: 691995759
2024-10-31 18:10:38 -07:00
Ayaka
f60b97cea1 [Pallas TPU] Add lowering for lax.nextafter
Also improved the corresponding test cases to ensure better coverage and accuracy.

This PR is similar to https://github.com/jax-ml/jax/pull/22283, which adds lowering for `lax.sign`.

PiperOrigin-RevId: 691988164
2024-10-31 17:34:38 -07:00
Peter Hawkins
84c8794b30 Add a JaxIrContext that subclasses mlir.ir.Context and avoids calling ir.Context's __init__.
mlir.ir.Context has the unfortunate behavior that it loads all dialects linked into the binary, even those we have no intention of using. This is fairly benign in JAX's usual configuration, but if JAX is linked together with other MLIR-using software it can be problematic.

PiperOrigin-RevId: 691984229
2024-10-31 17:18:08 -07:00
jax authors
423cd2ad5e Simplified conditional in flash attention.
PiperOrigin-RevId: 691972341
2024-10-31 16:28:11 -07:00
Jake VanderPlas
2b9c73d10d Remove a number of expired deprecations.
These APIs were all removed 3 or more months ago, and the registrations
here cause them to raise informative AttributeErrors. Enough time has
passed now that we can remove these.
2024-10-31 15:40:54 -07:00
Tzu-Wei Sung
7af7a60dcc [Pallas:TPU] Use arith.divui for uint32 div.
PiperOrigin-RevId: 691939453
2024-10-31 14:37:47 -07:00