473 Commits

Author SHA1 Message Date
Jake VanderPlas
13dd5e42cc Deprecate non-array inputs to jnp.array_equal & jnp.array_equiv 2023-11-28 13:55:18 -08:00
Peter Hawkins
84c1e825c0 Make jax.numpy.where()'s condition, x, y arguments positional-only to match numpy.where.
PiperOrigin-RevId: 584377134
2023-11-21 11:10:12 -08:00
Peter Hawkins
49c80e68d1 Fix error/hang when non-finite values are passed to non-symmetric Eigendecomposition.
Improve the documentation of lax.eig().

Fixes https://github.com/google/jax/issues/18226

PiperOrigin-RevId: 584170564
2023-11-20 17:32:16 -08:00
jax authors
dce6ab4548 Reverts 2aaa7559f96e4bb7b0271665bf386bf3ba22c451
PiperOrigin-RevId: 584033001
2023-11-20 08:23:40 -08:00
Peter Hawkins
30a0136813 Increase minimum jaxlib version to 0.4.19.
0.4.19 has xla_extension version 207 and mlir_api_version 54.

PiperOrigin-RevId: 583412447
2023-11-17 09:38:31 -08:00
Jake VanderPlas
84aa7e5c53 Deprecate passing of None to jax.numpy.array 2023-11-16 15:10:56 -08:00
Peter Hawkins
234be736c4 Reverts ef9075159a67a2b94526b65e4a2c2904a4a49046
PiperOrigin-RevId: 582789416
2023-11-15 13:35:52 -08:00
carlosgmartin
9f8e1bc34a Add nn.squareplus. 2023-11-14 23:52:41 -05:00
Peter Hawkins
ef9075159a Reverts 6401db3775bace69989cd76ccd328fc9a6cf0964
PiperOrigin-RevId: 582275667
2023-11-14 04:31:54 -08:00
Peter Hawkins
6401db3775 Make the CPU backend participate in distributed initialization.
The main effect of this change is that CPU devices end up with a unique global ID and the correct process index.

PiperOrigin-RevId: 582127068
2023-11-13 16:55:12 -08:00
Peter Hawkins
cb182b8b22 Use a Jacobi SVD solver for unbatched SVDs up to 1024x1024 on NVIDIA GPUs.
The unbatched Jacobi solver is faster for small-moderate matrices, and the unbatched kernel doesn't have size restrictions.

Timings on T4 GPU:

Before:

------------------------------------------------------------
Benchmark                  Time             CPU   Iterations
------------------------------------------------------------
svd/m:1/n:1           263587 ns       242274 ns         2780
svd/m:2/n:1           335561 ns       298238 ns         2303
svd/m:5/n:1           337784 ns       299841 ns         2304
svd/m:10/n:1          339184 ns       300703 ns         2311
svd/m:100/n:1         359826 ns       320088 ns         2159
svd/m:500/n:1         376124 ns       338660 ns         2076
svd/m:800/n:1         375779 ns       335590 ns         2060
svd/m:1000/n:1        419171 ns       341487 ns         2072
svd/m:1/n:2           307564 ns       270663 ns         2544
svd/m:2/n:2           320928 ns       283601 ns         2487
svd/m:5/n:2           377373 ns       344228 ns         2035
svd/m:10/n:2          380557 ns       349412 ns         1953
svd/m:100/n:2         435465 ns       403496 ns         1722
svd/m:500/n:2         444610 ns       410913 ns         1680
svd/m:800/n:2         454493 ns       416495 ns         1665
svd/m:1000/n:2        492110 ns       420539 ns         1665
svd/m:1/n:5           307316 ns       275833 ns         2531
svd/m:2/n:5           374318 ns       341432 ns         2086
svd/m:5/n:5           512928 ns       470293 ns         1361
svd/m:10/n:5          589330 ns       537070 ns         1353
svd/m:100/n:5         620164 ns       580166 ns         1193
svd/m:500/n:5         636424 ns       593692 ns         1180
svd/m:800/n:5         635545 ns       595016 ns         1181
svd/m:1000/n:5        672443 ns       597387 ns         1115
svd/m:1/n:10          310013 ns       273998 ns         2520
svd/m:2/n:10          370451 ns       334489 ns         2105
svd/m:5/n:10          560037 ns       522223 ns         1274
svd/m:10/n:10         572868 ns       535388 ns         1304
svd/m:100/n:10        959802 ns       918258 ns          765
svd/m:500/n:10        955958 ns       909778 ns          758
svd/m:800/n:10        924104 ns       879512 ns          777
svd/m:1000/n:10       950140 ns       883493 ns          775
svd/m:1/n:100         351237 ns       315554 ns         2198
svd/m:2/n:100         426883 ns       390089 ns         1792
svd/m:5/n:100         601557 ns       564493 ns         1255
svd/m:10/n:100        920819 ns       880011 ns          787
svd/m:100/n:100      7902281 ns      7229220 ns           95
svd/m:500/n:100      9720727 ns      9040679 ns           79
svd/m:800/n:100      9856378 ns      8998050 ns           79
svd/m:1000/n:100     9721017 ns      9086414 ns           79
svd/m:1/n:500         371171 ns       334217 ns         2117
svd/m:2/n:500         449165 ns       411499 ns         1700
svd/m:5/n:500         620354 ns       581866 ns         1185
svd/m:10/n:500        892375 ns       847239 ns          833
svd/m:100/n:500      9564810 ns      8867540 ns           79
svd/m:500/n:500    111924035 ns    104078023 ns            7
svd/m:800/n:500    147777319 ns    142730412 ns            5
svd/m:1000/n:500   154205084 ns    149740209 ns            5
svd/m:1/n:800         372122 ns       334212 ns         2119
svd/m:2/n:800         456672 ns       419260 ns         1680
svd/m:5/n:800         691208 ns       626003 ns         1190
svd/m:10/n:800       1017694 ns       941480 ns          730
svd/m:100/n:800      9892683 ns      9091043 ns           76
svd/m:500/n:800    144134235 ns    139129722 ns            5
svd/m:800/n:800    342790246 ns    333299774 ns            2
svd/m:1000/n:800   432820082 ns    427978978 ns            2
svd/m:1/n:1000        372785 ns       335745 ns         1805
svd/m:2/n:1000        451946 ns       413341 ns         1668
svd/m:5/n:1000        618475 ns       577213 ns         1169
svd/m:10/n:1000       907729 ns       863335 ns          808
svd/m:100/n:1000     9868543 ns      9116870 ns           76
svd/m:500/n:1000   156777811 ns    152042065 ns            5
svd/m:800/n:1000   429704070 ns    424677592 ns            2
svd/m:1000/n:1000  654864311 ns    642693162 ns            1

After:
------------------------------------------------------------
Benchmark                  Time             CPU   Iterations
------------------------------------------------------------
svd/m:1/n:1           265980 ns       245433 ns         2791
svd/m:2/n:1           340203 ns       302783 ns         2288
svd/m:5/n:1           337807 ns       301916 ns         2286
svd/m:10/n:1          338064 ns       302441 ns         2297
svd/m:100/n:1         335444 ns       298440 ns         2327
svd/m:500/n:1         338025 ns       302096 ns         2272
svd/m:800/n:1         328382 ns       291740 ns         2252
svd/m:1000/n:1        397494 ns       310905 ns         2239
svd/m:1/n:2           310464 ns       274507 ns         2535
svd/m:2/n:2           319999 ns       284247 ns         2515
svd/m:5/n:2           373435 ns       335919 ns         2069
svd/m:10/n:2          376327 ns       339327 ns         2056
svd/m:100/n:2         385061 ns       349258 ns         2003
svd/m:500/n:2         392352 ns       355735 ns         1932
svd/m:800/n:2         410736 ns       370677 ns         1881
svd/m:1000/n:2        494326 ns       405603 ns         1721
svd/m:1/n:5           316735 ns       277292 ns         2538
svd/m:2/n:5           383748 ns       342218 ns         2077
svd/m:5/n:5           494204 ns       454309 ns         1476
svd/m:10/n:5          547017 ns       508184 ns         1371
svd/m:100/n:5         514537 ns       476761 ns         1460
svd/m:500/n:5         544656 ns       504877 ns         1381
svd/m:800/n:5         642590 ns       599314 ns         1159
svd/m:1000/n:5        706166 ns       621209 ns         1106
svd/m:1/n:10          310825 ns       274374 ns         2511
svd/m:2/n:10          381316 ns       344202 ns         2094
svd/m:5/n:10          565469 ns       526759 ns         1266
svd/m:10/n:10         576111 ns       537286 ns         1299
svd/m:100/n:10        653250 ns       613392 ns         1137
svd/m:500/n:10        690532 ns       645828 ns         1080
svd/m:800/n:10        763924 ns       723677 ns          959
svd/m:1000/n:10       940342 ns       855517 ns          818
svd/m:1/n:100         306134 ns       271533 ns         2526
svd/m:2/n:100         374680 ns       339298 ns         2071
svd/m:5/n:100         576926 ns       539062 ns         1228
svd/m:10/n:100        656806 ns       615171 ns         1123
svd/m:100/n:100      3295164 ns      3138621 ns          223
svd/m:500/n:100      4269347 ns      4166000 ns          168
svd/m:800/n:100      4656541 ns      4522247 ns          154
svd/m:1000/n:100     6479223 ns      6354578 ns          112
svd/m:1/n:500         329966 ns       289083 ns         2440
svd/m:2/n:500         407535 ns       366794 ns         1947
svd/m:5/n:500         567367 ns       522809 ns         1336
svd/m:10/n:500        712307 ns       657608 ns         1065
svd/m:100/n:500      4262986 ns      4169907 ns          167
svd/m:500/n:500     28824720 ns     28650258 ns           25
svd/m:800/n:500     29330139 ns     28677269 ns           25
svd/m:1000/n:500    30848037 ns     30089216 ns           23
svd/m:1/n:800         328620 ns       289181 ns         2329
svd/m:2/n:800         419052 ns       379483 ns         1876
svd/m:5/n:800         587366 ns       546979 ns         1269
svd/m:10/n:800        830762 ns       787923 ns          893
svd/m:100/n:800      4763633 ns      4595738 ns          152
svd/m:500/n:800     30447861 ns     29949714 ns           24
svd/m:800/n:800     94188958 ns     93488372 ns            8
svd/m:1000/n:800    94701529 ns     93394677 ns            7
svd/m:1/n:1000        351102 ns       313099 ns         2218
svd/m:2/n:1000        446543 ns       407807 ns         1708
svd/m:5/n:1000        661152 ns       616174 ns         1129
svd/m:10/n:1000       915743 ns       873397 ns          802
svd/m:100/n:1000     6434730 ns      6282779 ns          113
svd/m:500/n:1000    30244321 ns     29684290 ns           24
svd/m:800/n:1000    92727423 ns     91477078 ns            8
svd/m:1000/n:1000  169500709 ns    168358420 ns            4
PiperOrigin-RevId: 582041508
2023-11-13 12:04:13 -08:00
jax authors
2aaa7559f9 Reverts a80d35f4eda99bac42a1573ff5029e7035ebe71b
PiperOrigin-RevId: 581285053
2023-11-10 09:32:29 -08:00
jax authors
a80d35f4ed Replace gcc with clang compiler in CI nightly jobs.
PiperOrigin-RevId: 581086810
2023-11-09 18:30:03 -08:00
Jake VanderPlas
340e655ac2 Remove deprecated sym_pos argument from jax.scipy.linalg.solve
PiperOrigin-RevId: 580940755
2023-11-09 09:53:37 -08:00
Skye Wanderman-Milne
55e3072d2e Update versions and CHANGELOG after jax 0.4.20 release 2023-11-02 16:30:56 -07:00
Peter Hawkins
47a76df7cc [IFRT] Fix incorrect type numbers for e4m3 and e5m2 types.
These types didn't match between xla::PrimitiveType and ifrt::DType.

Add a static_assert to enforce equality.

PiperOrigin-RevId: 576288042
2023-10-24 14:38:00 -07:00
jax authors
77de0df709 update change log's formatting.
PiperOrigin-RevId: 576214943
2023-10-24 11:52:24 -07:00
Peter Hawkins
8d49f9a159 Reverts 9b1a656c1ef1f93f9e93eccb662de9bebe66b51a
PiperOrigin-RevId: 576128882
2023-10-24 07:08:34 -07:00
jax authors
9b1a656c1e Reverts 84c516974ab2b37169938d2d48a6c29a63c62c21
PiperOrigin-RevId: 575891656
2023-10-23 12:13:56 -07:00
jax authors
dde17cd5bc Merge pull request #18180 from carlosgmartin:fill_diagonal
PiperOrigin-RevId: 575317151
2023-10-20 14:20:14 -07:00
carlosgmartin
3cb504c583 Add jax.numpy.fill_diagonal. 2023-10-20 16:47:46 -04:00
Yash Katariya
613369fc22 Finish 0.4.19 jax and jaxlib release
PiperOrigin-RevId: 574983871
2023-10-19 13:27:52 -07:00
Peter Hawkins
d856ecc6fb Set RPATH, not RUNPATH in JAX CUDA builds.
Fixes https://github.com/google/jax/issues/17497
2023-10-12 09:38:10 -07:00
Jake VanderPlas
117f4bdf9b Define jax.typing.DTypeLike 2023-10-10 08:46:36 -07:00
Peter Hawkins
4611d13c07 Only perform compilation cache writes from process 0.
This avoids problems with contending writes on filesystems such as GCS.

PiperOrigin-RevId: 572032482
2023-10-09 13:55:07 -07:00
Skye Wanderman-Milne
a06beaa1a2 Update versions post jax 0.4.18 release 2023-10-06 17:20:34 -07:00
Skye Wanderman-Milne
d4a1bb9292 Update setup.py and CHANGELOG for jax 0.4.18 release 2023-10-06 13:13:33 -07:00
jax authors
f0e4ea23cc Merge pull request #17987 from jakevdp:lax-dep
PiperOrigin-RevId: 571401660
2023-10-06 12:23:20 -07:00
Peter Hawkins
8c4d020db9 Improve CUDA install documentation.
Mention NCCL as a dependency, since it will be required by the next jaxlib release.
Mention LD_LIBRARY_PATH and PATH as how one overrides the CUDA installation for local installs.

Fixes #17831
2023-10-06 14:36:29 -04:00
Jake VanderPlas
ce6a0c43ad jax.lax: deprecate inadvertent exports & internal utilities 2023-10-06 11:26:03 -07:00
Peter Hawkins
efc18e4147 [JAX] Obtain NCCL via a stub, rather than linking it statically or dynamically.
This shrinks the CUDA jaxlib wheel size by around 80MB.

PiperOrigin-RevId: 570554454
2023-10-03 18:33:58 -07:00
Skye Wanderman-Milne
82b58386b7 Update versions and CHANGELOG after jax 0.4.17 release 2023-10-03 17:54:35 -07:00
Jake VanderPlas
a09fdf6e2f Add jax.numpy.bitwise_count() 2023-10-03 13:48:16 -07:00
Jake VanderPlas
9247a62b2b Add CHANGELOG entry for the jnp annotation change 2023-10-02 11:31:28 -07:00
Peter Hawkins
b7dfde8d87 Add notes about the new CUDA version restrictions to the changelog and installation instructions. 2023-09-27 15:56:47 -04:00
Peter Hawkins
a2e1f1f24e Update changelog.
Bump the minimum CUDA 12 pip package versions to the current releases.
2023-09-26 18:21:51 -04:00
Peter Hawkins
2fd6df45e4 Fix test failures under SciPy 1.11 for scipy.stats.mode. 2023-09-23 20:15:51 +00:00
Jake VanderPlas
243a6a236c dtypes.issubdtype: validate a when b is dtypes.extended 2023-09-21 15:53:05 -07:00
Jake VanderPlas
22818d664f [random] deprecate named key creation functions 2023-09-21 13:57:49 -07:00
Ayaka
74bc42e53e
Fix typo in CHANGELOG.md 2023-09-21 14:37:19 +08:00
Jake VanderPlas
024b1f23d7 Remove deprecated submodule jax.abstract_arrays 2023-09-19 15:40:18 -07:00
Yash Katariya
dcc465b4de Finish jax and jaxlib 0.4.16 release
PiperOrigin-RevId: 566477931
2023-09-18 19:09:19 -07:00
Yash Katariya
a2720ee2c3 Deprecate jax.experimental.pjit.with_sharding_constraint. Replacement is jax.lax.with_sharding_constraint which has been available since 1 year.
PiperOrigin-RevId: 565389746
2023-09-14 09:23:03 -07:00
Roy Frostig
1f8cc44f4e deprecate PRNGKeyArray.unsafe_raw_array in favor of jax.random.key_data
The latter function is also better in that its behavior is invariant to `jit`,
whereas the `unsafe_raw_array` method only works in eager mode.

PiperOrigin-RevId: 565195381
2023-09-13 16:33:56 -07:00
Jake VanderPlas
4e6c1b68c7 Deprecate random.KeyArray and random.PRNGKeyArray 2023-09-13 14:05:42 -07:00
Jake VanderPlas
eeb32a7d1f Finish deprecation cycle for abstract_arrays.ShapedArray & abstract_arrays.raise_to_shaped
PiperOrigin-RevId: 565142019
2023-09-13 13:21:46 -07:00
Jake VanderPlas
22ff7bd19a Finish the deprecation cycle for jnp.alltrue, jnp.sometrue, jnp.product, jnp.cumproduct
These have been deprecated in JAX following similar deprecations in numpy v1.25.0

PiperOrigin-RevId: 565122288
2023-09-13 12:07:36 -07:00
Peter Hawkins
408c657436 Add a release note about a fixed Windows crash. 2023-09-07 09:35:25 -04:00
Jake VanderPlas
ca39457ea9 JEX: move jax.linear_util to jax.extend.linear_util 2023-08-30 18:32:12 -07:00
Jake VanderPlas
4b89d03147 Deprecate the contents of jax.prng 2023-08-30 15:13:32 -07:00