Jake VanderPlas
0aec40a16f
Deprecate arr.device_buffer and arr.device_buffers
2023-11-29 15:31:01 -08:00
Jake VanderPlas
13dd5e42cc
Deprecate non-array inputs to jnp.array_equal & jnp.array_equiv
2023-11-28 13:55:18 -08:00
Peter Hawkins
84c1e825c0
Make jax.numpy.where()'s condition, x, y arguments positional-only to match numpy.where.
...
PiperOrigin-RevId: 584377134
2023-11-21 11:10:12 -08:00
Peter Hawkins
49c80e68d1
Fix error/hang when non-finite values are passed to non-symmetric Eigendecomposition.
...
Improve the documentation of lax.eig().
Fixes https://github.com/google/jax/issues/18226
PiperOrigin-RevId: 584170564
2023-11-20 17:32:16 -08:00
jax authors
dce6ab4548
Reverts 2aaa7559f96e4bb7b0271665bf386bf3ba22c451
...
PiperOrigin-RevId: 584033001
2023-11-20 08:23:40 -08:00
Peter Hawkins
30a0136813
Increase minimum jaxlib version to 0.4.19.
...
0.4.19 has xla_extension version 207 and mlir_api_version 54.
PiperOrigin-RevId: 583412447
2023-11-17 09:38:31 -08:00
Jake VanderPlas
84aa7e5c53
Deprecate passing of None to jax.numpy.array
2023-11-16 15:10:56 -08:00
Peter Hawkins
234be736c4
Reverts ef9075159a67a2b94526b65e4a2c2904a4a49046
...
PiperOrigin-RevId: 582789416
2023-11-15 13:35:52 -08:00
carlosgmartin
9f8e1bc34a
Add nn.squareplus.
2023-11-14 23:52:41 -05:00
Peter Hawkins
ef9075159a
Reverts 6401db3775bace69989cd76ccd328fc9a6cf0964
...
PiperOrigin-RevId: 582275667
2023-11-14 04:31:54 -08:00
Peter Hawkins
6401db3775
Make the CPU backend participate in distributed initialization.
...
The main effect of this change is that CPU devices end up with a unique global ID and the correct process index.
PiperOrigin-RevId: 582127068
2023-11-13 16:55:12 -08:00
Peter Hawkins
cb182b8b22
Use a Jacobi SVD solver for unbatched SVDs up to 1024x1024 on NVIDIA GPUs.
...
The unbatched Jacobi solver is faster for small-moderate matrices, and the unbatched kernel doesn't have size restrictions.
Timings on T4 GPU:
Before:
------------------------------------------------------------
Benchmark Time CPU Iterations
------------------------------------------------------------
svd/m:1/n:1 263587 ns 242274 ns 2780
svd/m:2/n:1 335561 ns 298238 ns 2303
svd/m:5/n:1 337784 ns 299841 ns 2304
svd/m:10/n:1 339184 ns 300703 ns 2311
svd/m:100/n:1 359826 ns 320088 ns 2159
svd/m:500/n:1 376124 ns 338660 ns 2076
svd/m:800/n:1 375779 ns 335590 ns 2060
svd/m:1000/n:1 419171 ns 341487 ns 2072
svd/m:1/n:2 307564 ns 270663 ns 2544
svd/m:2/n:2 320928 ns 283601 ns 2487
svd/m:5/n:2 377373 ns 344228 ns 2035
svd/m:10/n:2 380557 ns 349412 ns 1953
svd/m:100/n:2 435465 ns 403496 ns 1722
svd/m:500/n:2 444610 ns 410913 ns 1680
svd/m:800/n:2 454493 ns 416495 ns 1665
svd/m:1000/n:2 492110 ns 420539 ns 1665
svd/m:1/n:5 307316 ns 275833 ns 2531
svd/m:2/n:5 374318 ns 341432 ns 2086
svd/m:5/n:5 512928 ns 470293 ns 1361
svd/m:10/n:5 589330 ns 537070 ns 1353
svd/m:100/n:5 620164 ns 580166 ns 1193
svd/m:500/n:5 636424 ns 593692 ns 1180
svd/m:800/n:5 635545 ns 595016 ns 1181
svd/m:1000/n:5 672443 ns 597387 ns 1115
svd/m:1/n:10 310013 ns 273998 ns 2520
svd/m:2/n:10 370451 ns 334489 ns 2105
svd/m:5/n:10 560037 ns 522223 ns 1274
svd/m:10/n:10 572868 ns 535388 ns 1304
svd/m:100/n:10 959802 ns 918258 ns 765
svd/m:500/n:10 955958 ns 909778 ns 758
svd/m:800/n:10 924104 ns 879512 ns 777
svd/m:1000/n:10 950140 ns 883493 ns 775
svd/m:1/n:100 351237 ns 315554 ns 2198
svd/m:2/n:100 426883 ns 390089 ns 1792
svd/m:5/n:100 601557 ns 564493 ns 1255
svd/m:10/n:100 920819 ns 880011 ns 787
svd/m:100/n:100 7902281 ns 7229220 ns 95
svd/m:500/n:100 9720727 ns 9040679 ns 79
svd/m:800/n:100 9856378 ns 8998050 ns 79
svd/m:1000/n:100 9721017 ns 9086414 ns 79
svd/m:1/n:500 371171 ns 334217 ns 2117
svd/m:2/n:500 449165 ns 411499 ns 1700
svd/m:5/n:500 620354 ns 581866 ns 1185
svd/m:10/n:500 892375 ns 847239 ns 833
svd/m:100/n:500 9564810 ns 8867540 ns 79
svd/m:500/n:500 111924035 ns 104078023 ns 7
svd/m:800/n:500 147777319 ns 142730412 ns 5
svd/m:1000/n:500 154205084 ns 149740209 ns 5
svd/m:1/n:800 372122 ns 334212 ns 2119
svd/m:2/n:800 456672 ns 419260 ns 1680
svd/m:5/n:800 691208 ns 626003 ns 1190
svd/m:10/n:800 1017694 ns 941480 ns 730
svd/m:100/n:800 9892683 ns 9091043 ns 76
svd/m:500/n:800 144134235 ns 139129722 ns 5
svd/m:800/n:800 342790246 ns 333299774 ns 2
svd/m:1000/n:800 432820082 ns 427978978 ns 2
svd/m:1/n:1000 372785 ns 335745 ns 1805
svd/m:2/n:1000 451946 ns 413341 ns 1668
svd/m:5/n:1000 618475 ns 577213 ns 1169
svd/m:10/n:1000 907729 ns 863335 ns 808
svd/m:100/n:1000 9868543 ns 9116870 ns 76
svd/m:500/n:1000 156777811 ns 152042065 ns 5
svd/m:800/n:1000 429704070 ns 424677592 ns 2
svd/m:1000/n:1000 654864311 ns 642693162 ns 1
After:
------------------------------------------------------------
Benchmark Time CPU Iterations
------------------------------------------------------------
svd/m:1/n:1 265980 ns 245433 ns 2791
svd/m:2/n:1 340203 ns 302783 ns 2288
svd/m:5/n:1 337807 ns 301916 ns 2286
svd/m:10/n:1 338064 ns 302441 ns 2297
svd/m:100/n:1 335444 ns 298440 ns 2327
svd/m:500/n:1 338025 ns 302096 ns 2272
svd/m:800/n:1 328382 ns 291740 ns 2252
svd/m:1000/n:1 397494 ns 310905 ns 2239
svd/m:1/n:2 310464 ns 274507 ns 2535
svd/m:2/n:2 319999 ns 284247 ns 2515
svd/m:5/n:2 373435 ns 335919 ns 2069
svd/m:10/n:2 376327 ns 339327 ns 2056
svd/m:100/n:2 385061 ns 349258 ns 2003
svd/m:500/n:2 392352 ns 355735 ns 1932
svd/m:800/n:2 410736 ns 370677 ns 1881
svd/m:1000/n:2 494326 ns 405603 ns 1721
svd/m:1/n:5 316735 ns 277292 ns 2538
svd/m:2/n:5 383748 ns 342218 ns 2077
svd/m:5/n:5 494204 ns 454309 ns 1476
svd/m:10/n:5 547017 ns 508184 ns 1371
svd/m:100/n:5 514537 ns 476761 ns 1460
svd/m:500/n:5 544656 ns 504877 ns 1381
svd/m:800/n:5 642590 ns 599314 ns 1159
svd/m:1000/n:5 706166 ns 621209 ns 1106
svd/m:1/n:10 310825 ns 274374 ns 2511
svd/m:2/n:10 381316 ns 344202 ns 2094
svd/m:5/n:10 565469 ns 526759 ns 1266
svd/m:10/n:10 576111 ns 537286 ns 1299
svd/m:100/n:10 653250 ns 613392 ns 1137
svd/m:500/n:10 690532 ns 645828 ns 1080
svd/m:800/n:10 763924 ns 723677 ns 959
svd/m:1000/n:10 940342 ns 855517 ns 818
svd/m:1/n:100 306134 ns 271533 ns 2526
svd/m:2/n:100 374680 ns 339298 ns 2071
svd/m:5/n:100 576926 ns 539062 ns 1228
svd/m:10/n:100 656806 ns 615171 ns 1123
svd/m:100/n:100 3295164 ns 3138621 ns 223
svd/m:500/n:100 4269347 ns 4166000 ns 168
svd/m:800/n:100 4656541 ns 4522247 ns 154
svd/m:1000/n:100 6479223 ns 6354578 ns 112
svd/m:1/n:500 329966 ns 289083 ns 2440
svd/m:2/n:500 407535 ns 366794 ns 1947
svd/m:5/n:500 567367 ns 522809 ns 1336
svd/m:10/n:500 712307 ns 657608 ns 1065
svd/m:100/n:500 4262986 ns 4169907 ns 167
svd/m:500/n:500 28824720 ns 28650258 ns 25
svd/m:800/n:500 29330139 ns 28677269 ns 25
svd/m:1000/n:500 30848037 ns 30089216 ns 23
svd/m:1/n:800 328620 ns 289181 ns 2329
svd/m:2/n:800 419052 ns 379483 ns 1876
svd/m:5/n:800 587366 ns 546979 ns 1269
svd/m:10/n:800 830762 ns 787923 ns 893
svd/m:100/n:800 4763633 ns 4595738 ns 152
svd/m:500/n:800 30447861 ns 29949714 ns 24
svd/m:800/n:800 94188958 ns 93488372 ns 8
svd/m:1000/n:800 94701529 ns 93394677 ns 7
svd/m:1/n:1000 351102 ns 313099 ns 2218
svd/m:2/n:1000 446543 ns 407807 ns 1708
svd/m:5/n:1000 661152 ns 616174 ns 1129
svd/m:10/n:1000 915743 ns 873397 ns 802
svd/m:100/n:1000 6434730 ns 6282779 ns 113
svd/m:500/n:1000 30244321 ns 29684290 ns 24
svd/m:800/n:1000 92727423 ns 91477078 ns 8
svd/m:1000/n:1000 169500709 ns 168358420 ns 4
PiperOrigin-RevId: 582041508
2023-11-13 12:04:13 -08:00
jax authors
2aaa7559f9
Reverts a80d35f4eda99bac42a1573ff5029e7035ebe71b
...
PiperOrigin-RevId: 581285053
2023-11-10 09:32:29 -08:00
jax authors
a80d35f4ed
Replace gcc with clang compiler in CI nightly jobs.
...
PiperOrigin-RevId: 581086810
2023-11-09 18:30:03 -08:00
Jake VanderPlas
340e655ac2
Remove deprecated sym_pos
argument from jax.scipy.linalg.solve
...
PiperOrigin-RevId: 580940755
2023-11-09 09:53:37 -08:00
Skye Wanderman-Milne
55e3072d2e
Update versions and CHANGELOG after jax 0.4.20 release
2023-11-02 16:30:56 -07:00
Peter Hawkins
47a76df7cc
[IFRT] Fix incorrect type numbers for e4m3 and e5m2 types.
...
These types didn't match between xla::PrimitiveType and ifrt::DType.
Add a static_assert to enforce equality.
PiperOrigin-RevId: 576288042
2023-10-24 14:38:00 -07:00
jax authors
77de0df709
update change log's formatting.
...
PiperOrigin-RevId: 576214943
2023-10-24 11:52:24 -07:00
Peter Hawkins
8d49f9a159
Reverts 9b1a656c1ef1f93f9e93eccb662de9bebe66b51a
...
PiperOrigin-RevId: 576128882
2023-10-24 07:08:34 -07:00
jax authors
9b1a656c1e
Reverts 84c516974ab2b37169938d2d48a6c29a63c62c21
...
PiperOrigin-RevId: 575891656
2023-10-23 12:13:56 -07:00
jax authors
dde17cd5bc
Merge pull request #18180 from carlosgmartin:fill_diagonal
...
PiperOrigin-RevId: 575317151
2023-10-20 14:20:14 -07:00
carlosgmartin
3cb504c583
Add jax.numpy.fill_diagonal.
2023-10-20 16:47:46 -04:00
Yash Katariya
613369fc22
Finish 0.4.19 jax and jaxlib release
...
PiperOrigin-RevId: 574983871
2023-10-19 13:27:52 -07:00
Peter Hawkins
d856ecc6fb
Set RPATH, not RUNPATH in JAX CUDA builds.
...
Fixes https://github.com/google/jax/issues/17497
2023-10-12 09:38:10 -07:00
Jake VanderPlas
117f4bdf9b
Define jax.typing.DTypeLike
2023-10-10 08:46:36 -07:00
Peter Hawkins
4611d13c07
Only perform compilation cache writes from process 0.
...
This avoids problems with contending writes on filesystems such as GCS.
PiperOrigin-RevId: 572032482
2023-10-09 13:55:07 -07:00
Skye Wanderman-Milne
a06beaa1a2
Update versions post jax 0.4.18 release
2023-10-06 17:20:34 -07:00
Skye Wanderman-Milne
d4a1bb9292
Update setup.py and CHANGELOG for jax 0.4.18 release
2023-10-06 13:13:33 -07:00
jax authors
f0e4ea23cc
Merge pull request #17987 from jakevdp:lax-dep
...
PiperOrigin-RevId: 571401660
2023-10-06 12:23:20 -07:00
Peter Hawkins
8c4d020db9
Improve CUDA install documentation.
...
Mention NCCL as a dependency, since it will be required by the next jaxlib release.
Mention LD_LIBRARY_PATH and PATH as how one overrides the CUDA installation for local installs.
Fixes #17831
2023-10-06 14:36:29 -04:00
Jake VanderPlas
ce6a0c43ad
jax.lax: deprecate inadvertent exports & internal utilities
2023-10-06 11:26:03 -07:00
Peter Hawkins
efc18e4147
[JAX] Obtain NCCL via a stub, rather than linking it statically or dynamically.
...
This shrinks the CUDA jaxlib wheel size by around 80MB.
PiperOrigin-RevId: 570554454
2023-10-03 18:33:58 -07:00
Skye Wanderman-Milne
82b58386b7
Update versions and CHANGELOG after jax 0.4.17 release
2023-10-03 17:54:35 -07:00
Jake VanderPlas
a09fdf6e2f
Add jax.numpy.bitwise_count()
2023-10-03 13:48:16 -07:00
Jake VanderPlas
9247a62b2b
Add CHANGELOG entry for the jnp annotation change
2023-10-02 11:31:28 -07:00
Peter Hawkins
b7dfde8d87
Add notes about the new CUDA version restrictions to the changelog and installation instructions.
2023-09-27 15:56:47 -04:00
Peter Hawkins
a2e1f1f24e
Update changelog.
...
Bump the minimum CUDA 12 pip package versions to the current releases.
2023-09-26 18:21:51 -04:00
Peter Hawkins
2fd6df45e4
Fix test failures under SciPy 1.11 for scipy.stats.mode.
2023-09-23 20:15:51 +00:00
Jake VanderPlas
243a6a236c
dtypes.issubdtype: validate a when b is dtypes.extended
2023-09-21 15:53:05 -07:00
Jake VanderPlas
22818d664f
[random] deprecate named key creation functions
2023-09-21 13:57:49 -07:00
Ayaka
74bc42e53e
Fix typo in CHANGELOG.md
2023-09-21 14:37:19 +08:00
Jake VanderPlas
024b1f23d7
Remove deprecated submodule jax.abstract_arrays
2023-09-19 15:40:18 -07:00
Yash Katariya
dcc465b4de
Finish jax and jaxlib 0.4.16 release
...
PiperOrigin-RevId: 566477931
2023-09-18 19:09:19 -07:00
Yash Katariya
a2720ee2c3
Deprecate jax.experimental.pjit.with_sharding_constraint
. Replacement is jax.lax.with_sharding_constraint which has been available since 1 year.
...
PiperOrigin-RevId: 565389746
2023-09-14 09:23:03 -07:00
Roy Frostig
1f8cc44f4e
deprecate PRNGKeyArray.unsafe_raw_array
in favor of jax.random.key_data
...
The latter function is also better in that its behavior is invariant to `jit`,
whereas the `unsafe_raw_array` method only works in eager mode.
PiperOrigin-RevId: 565195381
2023-09-13 16:33:56 -07:00
Jake VanderPlas
4e6c1b68c7
Deprecate random.KeyArray and random.PRNGKeyArray
2023-09-13 14:05:42 -07:00
Jake VanderPlas
eeb32a7d1f
Finish deprecation cycle for abstract_arrays.ShapedArray & abstract_arrays.raise_to_shaped
...
PiperOrigin-RevId: 565142019
2023-09-13 13:21:46 -07:00
Jake VanderPlas
22ff7bd19a
Finish the deprecation cycle for jnp.alltrue, jnp.sometrue, jnp.product, jnp.cumproduct
...
These have been deprecated in JAX following similar deprecations in numpy v1.25.0
PiperOrigin-RevId: 565122288
2023-09-13 12:07:36 -07:00
Peter Hawkins
408c657436
Add a release note about a fixed Windows crash.
2023-09-07 09:35:25 -04:00
Jake VanderPlas
ca39457ea9
JEX: move jax.linear_util to jax.extend.linear_util
2023-08-30 18:32:12 -07:00