Bill Varcho
0ed6eaeb4a
[SDY] fix JAX layouts tests for Shardy.
...
PiperOrigin-RevId: 697715276
2024-11-18 12:14:32 -08:00
jax authors
70b05f6cde
Merge pull request #24952 from jakevdp:fix-pyi
...
PiperOrigin-RevId: 697698420
2024-11-18 11:25:43 -08:00
Jake VanderPlas
5bebd0f6c4
fix typo in numpy/__init__.pyi
2024-11-18 11:04:33 -08:00
Yash Katariya
6fe7b1713a
Return SingleDeviceSharding instead of GSPMDShardings when there is only 1 device during the compiled.input_shardings call.
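For illustration, a minimal sketch of where this change is visible, assuming the standard AOT `lower`/`compile` path:

```python
import jax
import jax.numpy as jnp

# Compile a trivial function ahead of time on the default (single) device.
compiled = jax.jit(lambda x: x + 1).lower(jnp.arange(4)).compile()

# With this change, on a single device the reported input shardings are
# SingleDeviceSharding objects rather than GSPMD-style shardings.
print(compiled.input_shardings)
```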
...
PiperOrigin-RevId: 697683233
2024-11-18 10:45:47 -08:00
jax authors
297a4e5ef5
Merge pull request #24903 from jakevdp:logsumexp
...
PiperOrigin-RevId: 697665013
2024-11-18 09:58:23 -08:00
Jake VanderPlas
e9864c69da
Make logaddexp and logaddexp2 into ufuncs
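A brief sketch of what the ufunc wrapping enables, assuming the usual `jnp.ufunc` method set (`reduce`, `accumulate`, etc.):

```python
import jax.numpy as jnp

x = jnp.array([-1.0, 0.5, 2.0, 3.5])

# Element-wise behavior is unchanged.
print(jnp.logaddexp(x, 0.0))

# As a ufunc, methods such as reduce become available; reducing with
# logaddexp computes a numerically stable log-sum-exp over the array.
print(jnp.logaddexp.reduce(x))
```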
2024-11-18 09:27:36 -08:00
jax authors
2de40e7dbf
Merge pull request #24916 from jakevdp:update-lp
...
PiperOrigin-RevId: 697652214
2024-11-18 09:19:09 -08:00
jax authors
05d66d7cd5
Merge pull request #24912 from jakevdp:jnp-module
...
PiperOrigin-RevId: 697646272
2024-11-18 09:01:27 -08:00
Nitin Srinivasan
14187399d7
Add new CI script for running Bazel GPU presubmits
...
PiperOrigin-RevId: 697643622
2024-11-18 08:52:51 -08:00
jax authors
65f9c7855e
Merge pull request #24932 from hawkinsp:gather
...
PiperOrigin-RevId: 697632284
2024-11-18 08:14:23 -08:00
Dan Foreman-Mackey
ccb331707e
Add a GPU implementation of lax.linalg.eig.
...
This feature has been in the queue for a long time (see https://github.com/jax-ml/jax/issues/1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (https://github.com/jax-ml/jax/issues/24255; this should be investigated separately).
This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on the host directly, because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.)
We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_use_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable.
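A minimal sketch of opting in to the MAGMA backend described above, assuming MAGMA is installed; the library path below is hypothetical, and the flag is assumed to be settable through the usual `jax.config.update` mechanism:

```python
import os
import jax
import jax.numpy as jnp
from jax import lax

# Optional: point JAX at a non-standard MAGMA install (hypothetical path).
os.environ["JAX_GPU_MAGMA_PATH"] = "/opt/magma/lib/libmagma.so"

# Opt in to the MAGMA-backed implementation instead of the host LAPACK path.
jax.config.update("jax_gpu_use_magma", "on")

# Nonsymmetric eigendecomposition on GPU; for matrices much larger than
# ~2048x2048 the MAGMA path is expected to be the faster option.
a = jnp.array([[0.0, 1.0], [-2.0, -3.0]])
w, vl, vr = lax.linalg.eig(a)
```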
PiperOrigin-RevId: 697631402
2024-11-18 08:11:57 -08:00
jax authors
afdc79271c
Merge pull request #24933 from hawkinsp:pow
...
PiperOrigin-RevId: 697622037
2024-11-18 07:38:56 -08:00
jax authors
f7ae0f99fe
Merge pull request #24930 from hawkinsp:dedup
...
PiperOrigin-RevId: 697544119
2024-11-18 01:58:51 -08:00
jax authors
ed250b8983
[AutoPGLE] Temporarily disable pgle_test in the OSS.
...
PiperOrigin-RevId: 697517161
2024-11-17 23:59:29 -08:00
jax authors
742cabc547
Update XLA dependency to use revision
...
58ea2935b4.
PiperOrigin-RevId: 697425145
2024-11-17 14:19:34 -08:00
Peter Hawkins
8a6c560b25
Use a direct StableHLO lowering for pow.
...
This is slightly faster than lowering via tracing, and the code is also simpler.
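As an aside, one way to inspect the result (a sketch; the exact op names in the output are not guaranteed):

```python
import jax
import jax.numpy as jnp

# With a direct lowering, the StableHLO for a floating-point power should
# contain a single power op rather than a traced Python-level expansion.
lowered = jax.jit(lambda x, y: jnp.power(x, y)).lower(2.0, 3.0)
print(lowered.as_text())
```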
2024-11-16 14:29:20 -08:00
jax authors
7b9914d711
Update XLA dependency to use revision
...
9ab7d704d7.
PiperOrigin-RevId: 697222155
2024-11-16 13:40:11 -08:00
Peter Hawkins
1d519f4ce3
Return an ndarray in shape_as_value if the shape is known to be constant.
2024-11-16 13:38:23 -08:00
Peter Hawkins
626aea017b
Deduplicate constants in StableHLO lowering.
...
The goal of this change is to reduce the size of the generated code: we frequently built thousands of scalar 0s, for example.
2024-11-16 12:05:26 -08:00
Yash Katariya
8525ef2b23
[sharding_in_types] Don't emit a wsc (with_sharding_constraint) under full manual mode, to avoid significantly increasing HLO size
...
PiperOrigin-RevId: 697048126
2024-11-15 17:42:16 -08:00
jax authors
efd232762c
Merge pull request #24917 from emilyfertig:emilyaf-sharp-bits
...
PiperOrigin-RevId: 697020253
2024-11-15 15:37:16 -08:00
Emily Fertig
225a2a5f8b
Consolidate material on PRNGs and add a short summary to Key Concepts.
2024-11-15 14:44:57 -08:00
jax authors
1aa5de66a8
Merge pull request #24914 from jakevdp:fix-pyi
...
PiperOrigin-RevId: 696989967
2024-11-15 13:49:53 -08:00
jax authors
605c605181
Merge pull request #24918 from emilyfertig:emilyaf-logical-op-example
...
PiperOrigin-RevId: 696989966
2024-11-15 13:48:14 -08:00
barnesjoseph
81cdc882ae
DOC: update main landing page style
...
Co-authored-by: Jake VanderPlas <jakevdp@google.com>
2024-11-15 13:44:31 -08:00
jax authors
1780ff2964
Update XLA dependency to use revision
...
195f45b708.
PiperOrigin-RevId: 696984108
2024-11-15 13:28:31 -08:00
Emily Fertig
5f1e3f5644
Add an example on logical operators to the tutorial.
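A hedged illustration of the kind of pitfall such an example covers (the code here is illustrative, not the tutorial's actual text):

```python
import jax
import jax.numpy as jnp

@jax.jit
def both_positive(x, y):
    # Array-aware logical ops work fine on traced values.
    return jnp.logical_and(x > 0, y > 0)

@jax.jit
def both_positive_py(x, y):
    # Python's `and` calls bool() on a traced value and raises an error
    # under jit, which is why the array-level operators are preferred.
    return (x > 0) and (y > 0)
```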
2024-11-15 12:40:41 -08:00
Jake VanderPlas
5f94284432
Add missing functions to jax.numpy type interface
2024-11-15 12:14:55 -08:00
jax authors
d8085008b7
Merge pull request #24913 from hawkinsp:threefry
...
PiperOrigin-RevId: 696915844
2024-11-15 09:42:54 -08:00
Peter Hawkins
23e9142d28
Lower threefry as an out-of-line MLIR function on TPU.
...
On TPU we're using an unrolled version of this function, and its expansion is large. It makes sense to emit it as few times as possible to reduce code size.
2024-11-15 08:49:35 -08:00
jax authors
1471702adc
[Mosaic TPU] Support 1D concat: set implicit_dim to kSecondMinor to treat 1D (N,) as (1, N) and then tile it as (1, 128)
...
PiperOrigin-RevId: 696870258
2024-11-15 06:41:57 -08:00
Jake VanderPlas
f652b6ad6a
Set __module__ attribute for objects in jax.numpy
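A quick sketch of the observable effect (assuming `jnp.arange` is among the updated objects):

```python
import jax.numpy as jnp

# With __module__ set, public jax.numpy functions report their canonical
# module, which helps introspection and documentation tooling.
print(jnp.arange.__module__)  # expected: "jax.numpy"
```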
2024-11-15 06:03:54 -08:00
Yash Katariya
9a0e9e55d8
[sharding_in_types] Handle collective axes in lowering rules more generally. If any axis is collective, set all dims of aval to unspecified dims in wrap_with_sharding_op.
...
Also lower shardings with `Collective` axes correctly to HloSharding.
PiperOrigin-RevId: 696703030
2024-11-14 17:32:01 -08:00
jax authors
4511f0c66b
Merge pull request #24862 from emilyfertig:emilyaf-control-flow-tutorial
...
PiperOrigin-RevId: 696692588
2024-11-14 16:50:14 -08:00
jax authors
1c31860fad
Merge pull request #24907 from jakevdp:array-api
...
PiperOrigin-RevId: 696688602
2024-11-14 16:33:23 -08:00
jax authors
c6051b3e15
Merge pull request #24881 from dfm:ffi-call-rep-rule
...
PiperOrigin-RevId: 696681818
2024-11-14 16:08:47 -08:00
Jake VanderPlas
a115b2cec5
Update array-api-tests commit
2024-11-14 16:05:30 -08:00
jax authors
4fe9164548
Merge pull request #24871 from carlosgmartin:numpy_put_along_axis
...
PiperOrigin-RevId: 696679735
2024-11-14 16:00:51 -08:00
jax authors
5764afb4b3
Merge pull request #24905 from jakevdp:old-arg
...
PiperOrigin-RevId: 696679346
2024-11-14 15:58:49 -08:00
jax authors
8e292122b7
Merge pull request #24567 from Intel-tensorflow:minigoel/intel-plugin
...
PiperOrigin-RevId: 696677564
2024-11-14 15:52:38 -08:00
Dan Foreman-Mackey
41a0493e56
Add shard map replication rule for ffi_call.
2024-11-14 15:44:31 -08:00
jax authors
c40d405e43
Update XLA dependency to use revision
...
ecdba3f23b.
PiperOrigin-RevId: 696642961
2024-11-14 14:04:36 -08:00
Jake VanderPlas
4a3e1155b9
cleanup: delete unused argument from internal reduction helper
2024-11-14 13:07:15 -08:00
jax authors
04d339d665
Merge pull request #24904 from jakevdp:array-api
...
PiperOrigin-RevId: 696623038
2024-11-14 13:00:42 -08:00
carlosgmartin
1f114b1cf7
Add numpy.put_along_axis.
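A small usage sketch; since JAX arrays are immutable, the function returns a new array, and an `inplace=False` argument is assumed to be required (mirroring `jnp.put`):

```python
import jax.numpy as jnp

arr = jnp.zeros((3, 4))
idx = jnp.array([[1], [3], [0]])             # one target column per row
vals = jnp.array([[10.0], [20.0], [30.0]])

# Returns a new array with `vals` scattered along axis 1 at `idx`.
out = jnp.put_along_axis(arr, idx, vals, axis=1, inplace=False)
print(out)
```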
2024-11-14 15:23:26 -05:00
jax authors
303b792ac2
Merge pull request #24864 from jakevdp:logaddexp2
...
PiperOrigin-RevId: 696602883
2024-11-14 11:56:42 -08:00
Jake VanderPlas
d0f36666ff
Update array-api-tests commit
2024-11-14 11:52:21 -08:00
Jake VanderPlas
d823f1720d
jnp.logaddexp2: simplify implementation
2024-11-14 11:35:23 -08:00
Emily Fertig
e6f6a8af8d
Move Control Flow text from Sharp Bits into its own tutorial.
2024-11-14 11:07:52 -08:00
jax authors
19a51de2ab
Merge pull request #24897 from hawkinsp:ipow
...
PiperOrigin-RevId: 696581990
2024-11-14 10:58:02 -08:00