24188 Commits

Author SHA1 Message Date
Bill Varcho
0ed6eaeb4a [SDY] fix JAX layouts tests for Shardy.
PiperOrigin-RevId: 697715276
2024-11-18 12:14:32 -08:00
jax authors
70b05f6cde Merge pull request #24952 from jakevdp:fix-pyi
PiperOrigin-RevId: 697698420
2024-11-18 11:25:43 -08:00
Jake VanderPlas
5bebd0f6c4 fix typo in numpy/__init__.pyi 2024-11-18 11:04:33 -08:00
Yash Katariya
6fe7b1713a Return SingleDeviceSharding instead of GSPMDShardings when there is only 1 device during compiled.input_shardings call.
PiperOrigin-RevId: 697683233
2024-11-18 10:45:47 -08:00
jax authors
297a4e5ef5 Merge pull request #24903 from jakevdp:logsumexp
PiperOrigin-RevId: 697665013
2024-11-18 09:58:23 -08:00
Jake VanderPlas
e9864c69da Make logaddexp and logaddexp2 into ufuncs 2024-11-18 09:27:36 -08:00
jax authors
2de40e7dbf Merge pull request #24916 from jakevdp:update-lp
PiperOrigin-RevId: 697652214
2024-11-18 09:19:09 -08:00
jax authors
05d66d7cd5 Merge pull request #24912 from jakevdp:jnp-module
PiperOrigin-RevId: 697646272
2024-11-18 09:01:27 -08:00
Nitin Srinivasan
14187399d7 Add new CI script for running Bazel GPU presubmits
PiperOrigin-RevId: 697643622
2024-11-18 08:52:51 -08:00
jax authors
65f9c7855e Merge pull request #24932 from hawkinsp:gather
PiperOrigin-RevId: 697632284
2024-11-18 08:14:23 -08:00
Dan Foreman-Mackey
ccb331707e Add a GPU implementation of lax.linalg.eig.
This feature has been in the queue for a long time (see https://github.com/jax-ml/jax/issues/1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (https://github.com/jax-ml/jax/issues/24255; this should be investigated separately).

This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.)

We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_use_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable.

PiperOrigin-RevId: 697631402
2024-11-18 08:11:57 -08:00
jax authors
afdc79271c Merge pull request #24933 from hawkinsp:pow
PiperOrigin-RevId: 697622037
2024-11-18 07:38:56 -08:00
jax authors
f7ae0f99fe Merge pull request #24930 from hawkinsp:dedup
PiperOrigin-RevId: 697544119
2024-11-18 01:58:51 -08:00
jax authors
ed250b8983 [AutoPGLE] Temporary disable pgle_test in the OSS.
PiperOrigin-RevId: 697517161
2024-11-17 23:59:29 -08:00
jax authors
742cabc547 Update XLA dependency to use revision
58ea2935b4.

PiperOrigin-RevId: 697425145
2024-11-17 14:19:34 -08:00
Peter Hawkins
8a6c560b25 Use a direct StableHLO lowering for pow.
This is slightly faster than lowering via tracing, and the code is simpler also.
2024-11-16 14:29:20 -08:00
jax authors
7b9914d711 Update XLA dependency to use revision
9ab7d704d7.

PiperOrigin-RevId: 697222155
2024-11-16 13:40:11 -08:00
Peter Hawkins
1d519f4ce3 Return a ndarray in shape_as_value if the shape is known to be constant. 2024-11-16 13:38:23 -08:00
Peter Hawkins
626aea017b Deduplicate constants in StableHLO lowering.
The goal of this change is to reduce the size of the generated code: we frequently built thousands of scalar 0s, for example.
2024-11-16 12:05:26 -08:00
Yash Katariya
8525ef2b23 [sharding_in_types] Don't emit a wsc under full manual mode to avoid increasing HLO size by a lot
PiperOrigin-RevId: 697048126
2024-11-15 17:42:16 -08:00
jax authors
efd232762c Merge pull request #24917 from emilyfertig:emilyaf-sharp-bits
PiperOrigin-RevId: 697020253
2024-11-15 15:37:16 -08:00
Emily Fertig
225a2a5f8b Consolidate material on PRNGs and add a short summary to Key Concepts. 2024-11-15 14:44:57 -08:00
jax authors
1aa5de66a8 Merge pull request #24914 from jakevdp:fix-pyi
PiperOrigin-RevId: 696989967
2024-11-15 13:49:53 -08:00
jax authors
605c605181 Merge pull request #24918 from emilyfertig:emilyaf-logical-op-example
PiperOrigin-RevId: 696989966
2024-11-15 13:48:14 -08:00
barnesjoseph
81cdc882ae DOC: update main landing page style
Co-authored-by: Jake VanderPlas <jakevdp@google.com>
2024-11-15 13:44:31 -08:00
jax authors
1780ff2964 Update XLA dependency to use revision
195f45b708.

PiperOrigin-RevId: 696984108
2024-11-15 13:28:31 -08:00
Emily Fertig
5f1e3f5644 Add an example on logical operators to the tutorial. 2024-11-15 12:40:41 -08:00
Jake VanderPlas
5f94284432 Add missing functions to jax.numpy type interface 2024-11-15 12:14:55 -08:00
jax authors
d8085008b7 Merge pull request #24913 from hawkinsp:threefry
PiperOrigin-RevId: 696915844
2024-11-15 09:42:54 -08:00
Peter Hawkins
23e9142d28 Lower threefry as an out-of-line MLIR function on TPU.
On TPU we're using an unrolled version of this function, and its expansion is large. It makes sense to emit it as few times as possible to reduce code size.
2024-11-15 08:49:35 -08:00
jax authors
1471702adc [Mosaic TPU] Support 1D concat: set implicit_dim to kSecondMinor to treat 1D (N,) as (1, N) and then tile it as (1, 128)
PiperOrigin-RevId: 696870258
2024-11-15 06:41:57 -08:00
Jake VanderPlas
f652b6ad6a Set __module__ attribute for objects in jax.numpy 2024-11-15 06:03:54 -08:00
Yash Katariya
9a0e9e55d8 [sharding_in_types] Handle collective axes in lowering rules more generally. If any axis is collective, set all dims of aval to unspecified dims in wrap_with_sharding_op.
Also lower shardings with `Collective` axes correctly to HloSharding.

PiperOrigin-RevId: 696703030
2024-11-14 17:32:01 -08:00
jax authors
4511f0c66b Merge pull request #24862 from emilyfertig:emilyaf-control-flow-tutorial
PiperOrigin-RevId: 696692588
2024-11-14 16:50:14 -08:00
jax authors
1c31860fad Merge pull request #24907 from jakevdp:array-api
PiperOrigin-RevId: 696688602
2024-11-14 16:33:23 -08:00
jax authors
c6051b3e15 Merge pull request #24881 from dfm:ffi-call-rep-rule
PiperOrigin-RevId: 696681818
2024-11-14 16:08:47 -08:00
Jake VanderPlas
a115b2cec5 Update array-api-tests commit 2024-11-14 16:05:30 -08:00
jax authors
4fe9164548 Merge pull request #24871 from carlosgmartin:numpy_put_along_axis
PiperOrigin-RevId: 696679735
2024-11-14 16:00:51 -08:00
jax authors
5764afb4b3 Merge pull request #24905 from jakevdp:old-arg
PiperOrigin-RevId: 696679346
2024-11-14 15:58:49 -08:00
jax authors
8e292122b7 Merge pull request #24567 from Intel-tensorflow:minigoel/intel-plugin
PiperOrigin-RevId: 696677564
2024-11-14 15:52:38 -08:00
Dan Foreman-Mackey
41a0493e56 Add shard map replication rule for ffi_call. 2024-11-14 15:44:31 -08:00
jax authors
c40d405e43 Update XLA dependency to use revision
ecdba3f23b.

PiperOrigin-RevId: 696642961
2024-11-14 14:04:36 -08:00
Jake VanderPlas
4a3e1155b9 cleanup: delete unused argument from internal reduction helper 2024-11-14 13:07:15 -08:00
jax authors
04d339d665 Merge pull request #24904 from jakevdp:array-api
PiperOrigin-RevId: 696623038
2024-11-14 13:00:42 -08:00
carlosgmartin
1f114b1cf7 Add numpy.put_along_axis. 2024-11-14 15:23:26 -05:00
jax authors
303b792ac2 Merge pull request #24864 from jakevdp:logaddexp2
PiperOrigin-RevId: 696602883
2024-11-14 11:56:42 -08:00
Jake VanderPlas
d0f36666ff Update array-api-tests commit 2024-11-14 11:52:21 -08:00
Jake VanderPlas
d823f1720d jnp.logaddexp2: simplify implementation 2024-11-14 11:35:23 -08:00
Emily Fertig
e6f6a8af8d Move Control Flow text from Sharp Bits into its own tutorial. 2024-11-14 11:07:52 -08:00
jax authors
19a51de2ab Merge pull request #24897 from hawkinsp:ipow
PiperOrigin-RevId: 696581990
2024-11-14 10:58:02 -08:00