496 Commits

Author SHA1 Message Date
George Necula
69788d18b6 [export] Refactor the imports for the public API of jax.experimental.export
Previously we used `from jax.experimental.export import export` and
`export.export(fun)`. Now we want to add the public API directly
to `jax.experimental.export`, for the following desired usage:

```
from jax.experimental import export

exp: export.Exported = export.export(fun)
ser: bytearray = export.serialize(exp)
exp1: export.Exported = export.deserialized(ser)
export.call(exp1)
```

This change requires changing the type of
`jax.experimental.export.export` from a
module to a function. This confuses
pytype for the targets with strict type checking,
which is why I attempt to make this change
atomically throughout the internal code base.

In order to support backwards compatibility with
OSS packages, this change also includes explicit
JAX version checks in several OSS packages, and
also adds to the `export` function the attributes
that the old export module had.

PiperOrigin-RevId: 596563481
2024-01-08 05:29:56 -08:00
George Necula
3195a069ef [shape_poly] Improved the tests for inequality comparisons.
Added more tests and broke some large tests into smaller ones.
2024-01-08 08:39:28 +02:00
George Necula
cd0e10f29b [shape_poly] Simplify and speed-up the __eq__ functions for symbolic expressions
Equality is used heavily for symbolic expressions because we use them
as dictionary keys or in sets. Previously, we used a more complete
and more expensive form of equality where we attempted to prove that
"e1 - e2 >= 0" and "e1 - e2 <= 0". This is an overkill and none
of the tests we have so far rely on this power. Now we just
normalize "e1 - e2" and if it reduces syntactically to an integer
we check if the integer is 0. If the difference does not reduce
to an integer we say that the expressions are disequal.

This may possibly change user-visible behavior when it depends
on the outcome of equality comparisons of symbolic dimensions
in presence of shape polymorphism.
2024-01-07 13:18:18 +02:00
Jake VanderPlas
8b62516676 [array api] add stable & descending params to jnp.sort & jnp.argsort 2024-01-04 14:21:25 -08:00
Jake VanderPlas
47e5c81a2c jnp.ndarray.item(): add args support 2024-01-03 13:03:47 -08:00
Jake VanderPlas
c06e186f60 Error on conversion of empty arrays to boolean.
PiperOrigin-RevId: 595264332
2024-01-02 19:26:45 -08:00
Jake VanderPlas
fff5ea579a Remove deprecated unsafe_raw_array method from PRNG keys
PiperOrigin-RevId: 595190146
2024-01-02 13:03:21 -08:00
Jake VanderPlas
cab63114b4 Remove deprecated function jax.numpy.trapz
This was deprecated prior to the JAX 0.4.16 release, so we have now met the 3 month deprecation period specified by the [API Compatiblity Policy](https://jax.readthedocs.io/en/latest/api_compatibility.html).

PiperOrigin-RevId: 592266215
2023-12-19 09:57:39 -08:00
Jake VanderPlas
e356d76913 Remove a number of deprecated APIs
All of these were deprecated prior to the JAX 0.4.16 release, on Sept 18 2023.
As of Monday Dec 18, we have met the 3 month deprecation period specified by the [API Compatiblity Policy](https://jax.readthedocs.io/en/latest/api_compatibility.html).

PiperOrigin-RevId: 591933493
2023-12-18 10:08:47 -08:00
Sergei Lebedev
41531123f4 Rolling back #18980, because it is not backwards compatible and breaks existing users.
Reverts 91faddd023c2df77df310f3f2f17eb2fa1e60df0

PiperOrigin-RevId: 591200403
2023-12-15 03:24:01 -08:00
George Necula
fd0f007765 [export] Refactor the imports for the public API of jax.experimental.export
Previously we used `from jax.experimental.export import export` and
`export.export(fun)`. Now we want to add the public API directly
to `jax.experimental.export`, for the following desired usage:

```
from jax.experimental import export

exp: export.Exported = export.export(fun)
ser: bytearray = export.serialize(exp)
exp1 = export.deserialized(ser)
export.call(exp1)
```

This change also includes a workaround to allow users to still
do `from jax.experimental.export import export`, for a while.
2023-12-15 10:00:05 +02:00
Yash Katariya
8bf3a86860 [roll forward 2] Remove the `jax_require_devices_during_lowering flag since it was temporary. Added the semi-breaking change to Changelog.md.
Reverts b52bcc1639368069075284eefc763f824ca155f1

PiperOrigin-RevId: 590959383
2023-12-14 09:14:25 -08:00
Yash Katariya
6e1ab7ca3f Finish release of jax and jaxlib 0.4.23
PiperOrigin-RevId: 590833947
2023-12-13 23:39:08 -08:00
Peter Hawkins
b392622647 Add patch to suppress XLA:GPU logging.
PiperOrigin-RevId: 590780227
2023-12-13 18:53:50 -08:00
Yash Katariya
25c16c0b78 Finish jax and jaxlib 0.4.22 release
PiperOrigin-RevId: 590775311
2023-12-13 18:26:47 -08:00
Yash Katariya
b52bcc1639 Reverts 3c07c10a9a55f9a32390dd10cf3f420bdf3f1ed8
PiperOrigin-RevId: 590700623
2023-12-13 13:45:14 -08:00
Yash Katariya
3c07c10a9a Remove the `jax_require_devices_during_lowering flag since it was temporary. Added the semi-breaking change to Changelog.md.
PiperOrigin-RevId: 590684939
2023-12-13 12:48:48 -08:00
Jake VanderPlas
35b84402c0 Deprecate arr.device_buffer and arr.device_buffers 2023-12-06 10:20:29 -08:00
Yash Katariya
a9bfbd32e1 Finish jax and jaxlib 0.4.21 release
PiperOrigin-RevId: 587866580
2023-12-04 15:51:58 -08:00
Yash Katariya
f0bc7e0fc6 Reverts f0382a5838f4526d21631e804f6fe576bfc3f97e
PiperOrigin-RevId: 587231484
2023-12-01 22:06:33 -08:00
jax authors
8ad774fb10 Automate arguments for jax.distributed.initialize for cloud TPU environments.
PiperOrigin-RevId: 586892544
2023-11-30 22:25:00 -08:00
Jake VanderPlas
97beb01c43 Deprecate the device() method of JAX arrays 2023-11-30 11:43:02 -08:00
Jake VanderPlas
0aec40a16f Deprecate arr.device_buffer and arr.device_buffers 2023-11-29 15:31:01 -08:00
Jake VanderPlas
13dd5e42cc Deprecate non-array inputs to jnp.array_equal & jnp.array_equiv 2023-11-28 13:55:18 -08:00
Peter Hawkins
84c1e825c0 Make jax.numpy.where()'s condition, x, y arguments positional-only to match numpy.where.
PiperOrigin-RevId: 584377134
2023-11-21 11:10:12 -08:00
Peter Hawkins
49c80e68d1 Fix error/hang when non-finite values are passed to non-symmetric Eigendecomposition.
Improve the documentation of lax.eig().

Fixes https://github.com/google/jax/issues/18226

PiperOrigin-RevId: 584170564
2023-11-20 17:32:16 -08:00
jax authors
dce6ab4548 Reverts 2aaa7559f96e4bb7b0271665bf386bf3ba22c451
PiperOrigin-RevId: 584033001
2023-11-20 08:23:40 -08:00
Peter Hawkins
30a0136813 Increase minimum jaxlib version to 0.4.19.
0.4.19 has xla_extension version 207 and mlir_api_version 54.

PiperOrigin-RevId: 583412447
2023-11-17 09:38:31 -08:00
Jake VanderPlas
84aa7e5c53 Deprecate passing of None to jax.numpy.array 2023-11-16 15:10:56 -08:00
Peter Hawkins
234be736c4 Reverts ef9075159a67a2b94526b65e4a2c2904a4a49046
PiperOrigin-RevId: 582789416
2023-11-15 13:35:52 -08:00
carlosgmartin
9f8e1bc34a Add nn.squareplus. 2023-11-14 23:52:41 -05:00
Peter Hawkins
ef9075159a Reverts 6401db3775bace69989cd76ccd328fc9a6cf0964
PiperOrigin-RevId: 582275667
2023-11-14 04:31:54 -08:00
Peter Hawkins
6401db3775 Make the CPU backend participate in distributed initialization.
The main effect of this change is that CPU devices end up with a unique global ID and the correct process index.

PiperOrigin-RevId: 582127068
2023-11-13 16:55:12 -08:00
Peter Hawkins
cb182b8b22 Use a Jacobi SVD solver for unbatched SVDs up to 1024x1024 on NVIDIA GPUs.
The unbatched Jacobi solver is faster for small-moderate matrices, and the unbatched kernel doesn't have size restrictions.

Timings on T4 GPU:

Before:

------------------------------------------------------------
Benchmark                  Time             CPU   Iterations
------------------------------------------------------------
svd/m:1/n:1           263587 ns       242274 ns         2780
svd/m:2/n:1           335561 ns       298238 ns         2303
svd/m:5/n:1           337784 ns       299841 ns         2304
svd/m:10/n:1          339184 ns       300703 ns         2311
svd/m:100/n:1         359826 ns       320088 ns         2159
svd/m:500/n:1         376124 ns       338660 ns         2076
svd/m:800/n:1         375779 ns       335590 ns         2060
svd/m:1000/n:1        419171 ns       341487 ns         2072
svd/m:1/n:2           307564 ns       270663 ns         2544
svd/m:2/n:2           320928 ns       283601 ns         2487
svd/m:5/n:2           377373 ns       344228 ns         2035
svd/m:10/n:2          380557 ns       349412 ns         1953
svd/m:100/n:2         435465 ns       403496 ns         1722
svd/m:500/n:2         444610 ns       410913 ns         1680
svd/m:800/n:2         454493 ns       416495 ns         1665
svd/m:1000/n:2        492110 ns       420539 ns         1665
svd/m:1/n:5           307316 ns       275833 ns         2531
svd/m:2/n:5           374318 ns       341432 ns         2086
svd/m:5/n:5           512928 ns       470293 ns         1361
svd/m:10/n:5          589330 ns       537070 ns         1353
svd/m:100/n:5         620164 ns       580166 ns         1193
svd/m:500/n:5         636424 ns       593692 ns         1180
svd/m:800/n:5         635545 ns       595016 ns         1181
svd/m:1000/n:5        672443 ns       597387 ns         1115
svd/m:1/n:10          310013 ns       273998 ns         2520
svd/m:2/n:10          370451 ns       334489 ns         2105
svd/m:5/n:10          560037 ns       522223 ns         1274
svd/m:10/n:10         572868 ns       535388 ns         1304
svd/m:100/n:10        959802 ns       918258 ns          765
svd/m:500/n:10        955958 ns       909778 ns          758
svd/m:800/n:10        924104 ns       879512 ns          777
svd/m:1000/n:10       950140 ns       883493 ns          775
svd/m:1/n:100         351237 ns       315554 ns         2198
svd/m:2/n:100         426883 ns       390089 ns         1792
svd/m:5/n:100         601557 ns       564493 ns         1255
svd/m:10/n:100        920819 ns       880011 ns          787
svd/m:100/n:100      7902281 ns      7229220 ns           95
svd/m:500/n:100      9720727 ns      9040679 ns           79
svd/m:800/n:100      9856378 ns      8998050 ns           79
svd/m:1000/n:100     9721017 ns      9086414 ns           79
svd/m:1/n:500         371171 ns       334217 ns         2117
svd/m:2/n:500         449165 ns       411499 ns         1700
svd/m:5/n:500         620354 ns       581866 ns         1185
svd/m:10/n:500        892375 ns       847239 ns          833
svd/m:100/n:500      9564810 ns      8867540 ns           79
svd/m:500/n:500    111924035 ns    104078023 ns            7
svd/m:800/n:500    147777319 ns    142730412 ns            5
svd/m:1000/n:500   154205084 ns    149740209 ns            5
svd/m:1/n:800         372122 ns       334212 ns         2119
svd/m:2/n:800         456672 ns       419260 ns         1680
svd/m:5/n:800         691208 ns       626003 ns         1190
svd/m:10/n:800       1017694 ns       941480 ns          730
svd/m:100/n:800      9892683 ns      9091043 ns           76
svd/m:500/n:800    144134235 ns    139129722 ns            5
svd/m:800/n:800    342790246 ns    333299774 ns            2
svd/m:1000/n:800   432820082 ns    427978978 ns            2
svd/m:1/n:1000        372785 ns       335745 ns         1805
svd/m:2/n:1000        451946 ns       413341 ns         1668
svd/m:5/n:1000        618475 ns       577213 ns         1169
svd/m:10/n:1000       907729 ns       863335 ns          808
svd/m:100/n:1000     9868543 ns      9116870 ns           76
svd/m:500/n:1000   156777811 ns    152042065 ns            5
svd/m:800/n:1000   429704070 ns    424677592 ns            2
svd/m:1000/n:1000  654864311 ns    642693162 ns            1

After:
------------------------------------------------------------
Benchmark                  Time             CPU   Iterations
------------------------------------------------------------
svd/m:1/n:1           265980 ns       245433 ns         2791
svd/m:2/n:1           340203 ns       302783 ns         2288
svd/m:5/n:1           337807 ns       301916 ns         2286
svd/m:10/n:1          338064 ns       302441 ns         2297
svd/m:100/n:1         335444 ns       298440 ns         2327
svd/m:500/n:1         338025 ns       302096 ns         2272
svd/m:800/n:1         328382 ns       291740 ns         2252
svd/m:1000/n:1        397494 ns       310905 ns         2239
svd/m:1/n:2           310464 ns       274507 ns         2535
svd/m:2/n:2           319999 ns       284247 ns         2515
svd/m:5/n:2           373435 ns       335919 ns         2069
svd/m:10/n:2          376327 ns       339327 ns         2056
svd/m:100/n:2         385061 ns       349258 ns         2003
svd/m:500/n:2         392352 ns       355735 ns         1932
svd/m:800/n:2         410736 ns       370677 ns         1881
svd/m:1000/n:2        494326 ns       405603 ns         1721
svd/m:1/n:5           316735 ns       277292 ns         2538
svd/m:2/n:5           383748 ns       342218 ns         2077
svd/m:5/n:5           494204 ns       454309 ns         1476
svd/m:10/n:5          547017 ns       508184 ns         1371
svd/m:100/n:5         514537 ns       476761 ns         1460
svd/m:500/n:5         544656 ns       504877 ns         1381
svd/m:800/n:5         642590 ns       599314 ns         1159
svd/m:1000/n:5        706166 ns       621209 ns         1106
svd/m:1/n:10          310825 ns       274374 ns         2511
svd/m:2/n:10          381316 ns       344202 ns         2094
svd/m:5/n:10          565469 ns       526759 ns         1266
svd/m:10/n:10         576111 ns       537286 ns         1299
svd/m:100/n:10        653250 ns       613392 ns         1137
svd/m:500/n:10        690532 ns       645828 ns         1080
svd/m:800/n:10        763924 ns       723677 ns          959
svd/m:1000/n:10       940342 ns       855517 ns          818
svd/m:1/n:100         306134 ns       271533 ns         2526
svd/m:2/n:100         374680 ns       339298 ns         2071
svd/m:5/n:100         576926 ns       539062 ns         1228
svd/m:10/n:100        656806 ns       615171 ns         1123
svd/m:100/n:100      3295164 ns      3138621 ns          223
svd/m:500/n:100      4269347 ns      4166000 ns          168
svd/m:800/n:100      4656541 ns      4522247 ns          154
svd/m:1000/n:100     6479223 ns      6354578 ns          112
svd/m:1/n:500         329966 ns       289083 ns         2440
svd/m:2/n:500         407535 ns       366794 ns         1947
svd/m:5/n:500         567367 ns       522809 ns         1336
svd/m:10/n:500        712307 ns       657608 ns         1065
svd/m:100/n:500      4262986 ns      4169907 ns          167
svd/m:500/n:500     28824720 ns     28650258 ns           25
svd/m:800/n:500     29330139 ns     28677269 ns           25
svd/m:1000/n:500    30848037 ns     30089216 ns           23
svd/m:1/n:800         328620 ns       289181 ns         2329
svd/m:2/n:800         419052 ns       379483 ns         1876
svd/m:5/n:800         587366 ns       546979 ns         1269
svd/m:10/n:800        830762 ns       787923 ns          893
svd/m:100/n:800      4763633 ns      4595738 ns          152
svd/m:500/n:800     30447861 ns     29949714 ns           24
svd/m:800/n:800     94188958 ns     93488372 ns            8
svd/m:1000/n:800    94701529 ns     93394677 ns            7
svd/m:1/n:1000        351102 ns       313099 ns         2218
svd/m:2/n:1000        446543 ns       407807 ns         1708
svd/m:5/n:1000        661152 ns       616174 ns         1129
svd/m:10/n:1000       915743 ns       873397 ns          802
svd/m:100/n:1000     6434730 ns      6282779 ns          113
svd/m:500/n:1000    30244321 ns     29684290 ns           24
svd/m:800/n:1000    92727423 ns     91477078 ns            8
svd/m:1000/n:1000  169500709 ns    168358420 ns            4
PiperOrigin-RevId: 582041508
2023-11-13 12:04:13 -08:00
jax authors
2aaa7559f9 Reverts a80d35f4eda99bac42a1573ff5029e7035ebe71b
PiperOrigin-RevId: 581285053
2023-11-10 09:32:29 -08:00
jax authors
a80d35f4ed Replace gcc with clang compiler in CI nightly jobs.
PiperOrigin-RevId: 581086810
2023-11-09 18:30:03 -08:00
Jake VanderPlas
340e655ac2 Remove deprecated sym_pos argument from jax.scipy.linalg.solve
PiperOrigin-RevId: 580940755
2023-11-09 09:53:37 -08:00
Skye Wanderman-Milne
55e3072d2e Update versions and CHANGELOG after jax 0.4.20 release 2023-11-02 16:30:56 -07:00
Peter Hawkins
47a76df7cc [IFRT] Fix incorrect type numbers for e4m3 and e5m2 types.
These types didn't match between xla::PrimitiveType and ifrt::DType.

Add a static_assert to enforce equality.

PiperOrigin-RevId: 576288042
2023-10-24 14:38:00 -07:00
jax authors
77de0df709 update change log's formatting.
PiperOrigin-RevId: 576214943
2023-10-24 11:52:24 -07:00
Peter Hawkins
8d49f9a159 Reverts 9b1a656c1ef1f93f9e93eccb662de9bebe66b51a
PiperOrigin-RevId: 576128882
2023-10-24 07:08:34 -07:00
jax authors
9b1a656c1e Reverts 84c516974ab2b37169938d2d48a6c29a63c62c21
PiperOrigin-RevId: 575891656
2023-10-23 12:13:56 -07:00
jax authors
dde17cd5bc Merge pull request #18180 from carlosgmartin:fill_diagonal
PiperOrigin-RevId: 575317151
2023-10-20 14:20:14 -07:00
carlosgmartin
3cb504c583 Add jax.numpy.fill_diagonal. 2023-10-20 16:47:46 -04:00
Yash Katariya
613369fc22 Finish 0.4.19 jax and jaxlib release
PiperOrigin-RevId: 574983871
2023-10-19 13:27:52 -07:00
Peter Hawkins
d856ecc6fb Set RPATH, not RUNPATH in JAX CUDA builds.
Fixes https://github.com/google/jax/issues/17497
2023-10-12 09:38:10 -07:00
Jake VanderPlas
117f4bdf9b Define jax.typing.DTypeLike 2023-10-10 08:46:36 -07:00
Peter Hawkins
4611d13c07 Only perform compilation cache writes from process 0.
This avoids problems with contending writes on filesystems such as GCS.

PiperOrigin-RevId: 572032482
2023-10-09 13:55:07 -07:00
Skye Wanderman-Milne
a06beaa1a2 Update versions post jax 0.4.18 release 2023-10-06 17:20:34 -07:00
Skye Wanderman-Milne
d4a1bb9292 Update setup.py and CHANGELOG for jax 0.4.18 release 2023-10-06 13:13:33 -07:00