6627 Commits

Author SHA1 Message Date
Peter Hawkins
cb182b8b22 Use a Jacobi SVD solver for unbatched SVDs up to 1024x1024 on NVIDIA GPUs.
The unbatched Jacobi solver is faster for small-moderate matrices, and the unbatched kernel doesn't have size restrictions.

Timings on T4 GPU:

Before:

------------------------------------------------------------
Benchmark                  Time             CPU   Iterations
------------------------------------------------------------
svd/m:1/n:1           263587 ns       242274 ns         2780
svd/m:2/n:1           335561 ns       298238 ns         2303
svd/m:5/n:1           337784 ns       299841 ns         2304
svd/m:10/n:1          339184 ns       300703 ns         2311
svd/m:100/n:1         359826 ns       320088 ns         2159
svd/m:500/n:1         376124 ns       338660 ns         2076
svd/m:800/n:1         375779 ns       335590 ns         2060
svd/m:1000/n:1        419171 ns       341487 ns         2072
svd/m:1/n:2           307564 ns       270663 ns         2544
svd/m:2/n:2           320928 ns       283601 ns         2487
svd/m:5/n:2           377373 ns       344228 ns         2035
svd/m:10/n:2          380557 ns       349412 ns         1953
svd/m:100/n:2         435465 ns       403496 ns         1722
svd/m:500/n:2         444610 ns       410913 ns         1680
svd/m:800/n:2         454493 ns       416495 ns         1665
svd/m:1000/n:2        492110 ns       420539 ns         1665
svd/m:1/n:5           307316 ns       275833 ns         2531
svd/m:2/n:5           374318 ns       341432 ns         2086
svd/m:5/n:5           512928 ns       470293 ns         1361
svd/m:10/n:5          589330 ns       537070 ns         1353
svd/m:100/n:5         620164 ns       580166 ns         1193
svd/m:500/n:5         636424 ns       593692 ns         1180
svd/m:800/n:5         635545 ns       595016 ns         1181
svd/m:1000/n:5        672443 ns       597387 ns         1115
svd/m:1/n:10          310013 ns       273998 ns         2520
svd/m:2/n:10          370451 ns       334489 ns         2105
svd/m:5/n:10          560037 ns       522223 ns         1274
svd/m:10/n:10         572868 ns       535388 ns         1304
svd/m:100/n:10        959802 ns       918258 ns          765
svd/m:500/n:10        955958 ns       909778 ns          758
svd/m:800/n:10        924104 ns       879512 ns          777
svd/m:1000/n:10       950140 ns       883493 ns          775
svd/m:1/n:100         351237 ns       315554 ns         2198
svd/m:2/n:100         426883 ns       390089 ns         1792
svd/m:5/n:100         601557 ns       564493 ns         1255
svd/m:10/n:100        920819 ns       880011 ns          787
svd/m:100/n:100      7902281 ns      7229220 ns           95
svd/m:500/n:100      9720727 ns      9040679 ns           79
svd/m:800/n:100      9856378 ns      8998050 ns           79
svd/m:1000/n:100     9721017 ns      9086414 ns           79
svd/m:1/n:500         371171 ns       334217 ns         2117
svd/m:2/n:500         449165 ns       411499 ns         1700
svd/m:5/n:500         620354 ns       581866 ns         1185
svd/m:10/n:500        892375 ns       847239 ns          833
svd/m:100/n:500      9564810 ns      8867540 ns           79
svd/m:500/n:500    111924035 ns    104078023 ns            7
svd/m:800/n:500    147777319 ns    142730412 ns            5
svd/m:1000/n:500   154205084 ns    149740209 ns            5
svd/m:1/n:800         372122 ns       334212 ns         2119
svd/m:2/n:800         456672 ns       419260 ns         1680
svd/m:5/n:800         691208 ns       626003 ns         1190
svd/m:10/n:800       1017694 ns       941480 ns          730
svd/m:100/n:800      9892683 ns      9091043 ns           76
svd/m:500/n:800    144134235 ns    139129722 ns            5
svd/m:800/n:800    342790246 ns    333299774 ns            2
svd/m:1000/n:800   432820082 ns    427978978 ns            2
svd/m:1/n:1000        372785 ns       335745 ns         1805
svd/m:2/n:1000        451946 ns       413341 ns         1668
svd/m:5/n:1000        618475 ns       577213 ns         1169
svd/m:10/n:1000       907729 ns       863335 ns          808
svd/m:100/n:1000     9868543 ns      9116870 ns           76
svd/m:500/n:1000   156777811 ns    152042065 ns            5
svd/m:800/n:1000   429704070 ns    424677592 ns            2
svd/m:1000/n:1000  654864311 ns    642693162 ns            1

After:
------------------------------------------------------------
Benchmark                  Time             CPU   Iterations
------------------------------------------------------------
svd/m:1/n:1           265980 ns       245433 ns         2791
svd/m:2/n:1           340203 ns       302783 ns         2288
svd/m:5/n:1           337807 ns       301916 ns         2286
svd/m:10/n:1          338064 ns       302441 ns         2297
svd/m:100/n:1         335444 ns       298440 ns         2327
svd/m:500/n:1         338025 ns       302096 ns         2272
svd/m:800/n:1         328382 ns       291740 ns         2252
svd/m:1000/n:1        397494 ns       310905 ns         2239
svd/m:1/n:2           310464 ns       274507 ns         2535
svd/m:2/n:2           319999 ns       284247 ns         2515
svd/m:5/n:2           373435 ns       335919 ns         2069
svd/m:10/n:2          376327 ns       339327 ns         2056
svd/m:100/n:2         385061 ns       349258 ns         2003
svd/m:500/n:2         392352 ns       355735 ns         1932
svd/m:800/n:2         410736 ns       370677 ns         1881
svd/m:1000/n:2        494326 ns       405603 ns         1721
svd/m:1/n:5           316735 ns       277292 ns         2538
svd/m:2/n:5           383748 ns       342218 ns         2077
svd/m:5/n:5           494204 ns       454309 ns         1476
svd/m:10/n:5          547017 ns       508184 ns         1371
svd/m:100/n:5         514537 ns       476761 ns         1460
svd/m:500/n:5         544656 ns       504877 ns         1381
svd/m:800/n:5         642590 ns       599314 ns         1159
svd/m:1000/n:5        706166 ns       621209 ns         1106
svd/m:1/n:10          310825 ns       274374 ns         2511
svd/m:2/n:10          381316 ns       344202 ns         2094
svd/m:5/n:10          565469 ns       526759 ns         1266
svd/m:10/n:10         576111 ns       537286 ns         1299
svd/m:100/n:10        653250 ns       613392 ns         1137
svd/m:500/n:10        690532 ns       645828 ns         1080
svd/m:800/n:10        763924 ns       723677 ns          959
svd/m:1000/n:10       940342 ns       855517 ns          818
svd/m:1/n:100         306134 ns       271533 ns         2526
svd/m:2/n:100         374680 ns       339298 ns         2071
svd/m:5/n:100         576926 ns       539062 ns         1228
svd/m:10/n:100        656806 ns       615171 ns         1123
svd/m:100/n:100      3295164 ns      3138621 ns          223
svd/m:500/n:100      4269347 ns      4166000 ns          168
svd/m:800/n:100      4656541 ns      4522247 ns          154
svd/m:1000/n:100     6479223 ns      6354578 ns          112
svd/m:1/n:500         329966 ns       289083 ns         2440
svd/m:2/n:500         407535 ns       366794 ns         1947
svd/m:5/n:500         567367 ns       522809 ns         1336
svd/m:10/n:500        712307 ns       657608 ns         1065
svd/m:100/n:500      4262986 ns      4169907 ns          167
svd/m:500/n:500     28824720 ns     28650258 ns           25
svd/m:800/n:500     29330139 ns     28677269 ns           25
svd/m:1000/n:500    30848037 ns     30089216 ns           23
svd/m:1/n:800         328620 ns       289181 ns         2329
svd/m:2/n:800         419052 ns       379483 ns         1876
svd/m:5/n:800         587366 ns       546979 ns         1269
svd/m:10/n:800        830762 ns       787923 ns          893
svd/m:100/n:800      4763633 ns      4595738 ns          152
svd/m:500/n:800     30447861 ns     29949714 ns           24
svd/m:800/n:800     94188958 ns     93488372 ns            8
svd/m:1000/n:800    94701529 ns     93394677 ns            7
svd/m:1/n:1000        351102 ns       313099 ns         2218
svd/m:2/n:1000        446543 ns       407807 ns         1708
svd/m:5/n:1000        661152 ns       616174 ns         1129
svd/m:10/n:1000       915743 ns       873397 ns          802
svd/m:100/n:1000     6434730 ns      6282779 ns          113
svd/m:500/n:1000    30244321 ns     29684290 ns           24
svd/m:800/n:1000    92727423 ns     91477078 ns            8
svd/m:1000/n:1000  169500709 ns    168358420 ns            4
PiperOrigin-RevId: 582041508
2023-11-13 12:04:13 -08:00
jax authors
871b79925e Fix test failures when we update the abseil hashtable implementation.
PiperOrigin-RevId: 581988519
2023-11-13 09:24:06 -08:00
Peter Hawkins
1611e1bc41 Remove PythonJitTest from api_test.py.
Ever since the jit-pjit merge, the "Python" jit test has actually just called the same code as the "C++" jit test. We don't have a C++-free jit path any more. Remove the "Python" tests since they don't test anything.

PiperOrigin-RevId: 581965049
2023-11-13 08:03:23 -08:00
Junwhan Ahn
55394a0914 Roll back the optimized version of jax.block_until_ready due to test breakage
Reverts 6cc6d093643c0265c7de4027f79879f6945e0342

PiperOrigin-RevId: 581577789
2023-11-11 12:15:45 -08:00
jax authors
45982c8439 Update test since C PjRt API topology description is available.
The compilation_cache_test had an exclusion since the C PjRt
topology description had not been implemented. Now that it is
available, remove the exclusion.

PiperOrigin-RevId: 581396824
2023-11-10 16:12:06 -08:00
Jieying Luo
11236dbe34 Disable profiler test for older plugins.
PiperOrigin-RevId: 581391435
2023-11-10 15:58:20 -08:00
Jake VanderPlas
a9452b98a3 jnp.vectorize: support None arguments 2023-11-10 14:25:42 -08:00
Jake VanderPlas
c0f3fa00f8 [random] support key dtype in custom_jvp
To do this, we introduce a dtype for key tangents which cannot be used
to generate random values
2023-11-10 11:16:23 -08:00
Junwhan Ahn
6cc6d09364 Implement more efficient jax.block_until_ready(x) in C++
The current implementation synchronously calls `ArrayImpl.block_until_ready()` one by one. This is suboptimal when it's not cheap to query the readiness of an array. Also, calling `x.block_until_ready()` causes GIL to be acquired/released repeatedly.

To address this issue, this CL introduces a C++ implementation of `jax.block_until_ready(x)` that uses IFRT's `Array::GetReadyFuture()` to asynchronously query the readiness of all arrays and wait for them once. To preserve the previous behavior, the C++ implementation also has a slow path for any non-PyArray objects that implement `block_until_ready`.

PiperOrigin-RevId: 581302290
2023-11-10 10:34:34 -08:00
jax authors
fc6ed3bc68 Merge pull request #18463 from jakevdp:eye-offset
PiperOrigin-RevId: 581020737
2023-11-09 14:11:09 -08:00
Jieying Luo
3f1900e2e3 [PJRT C API] Add a util method to get the PJRT C API version of the backend.
Disable some memories tests which are not supported on plugin older than 0.32.

PiperOrigin-RevId: 581008059
2023-11-09 13:30:19 -08:00
Yash Katariya
cf3c041366 Disable jax memories flag.
PiperOrigin-RevId: 580961421
2023-11-09 10:54:02 -08:00
Jake VanderPlas
4dd6334265 jnp.eye: handle larger-than int32 offsets 2023-11-09 10:23:49 -08:00
Sharad Vikram
8fbcfce2dd [Pallas] Enable interpreter mode as default lowering for CPU
PiperOrigin-RevId: 580700740
2023-11-08 16:35:31 -08:00
jax authors
62e8f4d4aa Merge pull request #18431 from jakevdp:hist-ravel
PiperOrigin-RevId: 580597288
2023-11-08 11:37:15 -08:00
jax authors
6efcfe8fe0 Merge pull request #18386 from shacklettbp:pallas
PiperOrigin-RevId: 580561989
2023-11-08 09:48:27 -08:00
Jake VanderPlas
a30d51ba2e jnp.histogram: avoid flattening input 2023-11-08 08:55:09 -08:00
Peter Hawkins
f4eb3f6d86 Add a regression test for a pmap issue that is fixed at head.
Fixes https://github.com/google/jax/issues/5757

PiperOrigin-RevId: 580243825
2023-11-07 11:21:21 -08:00
Jake VanderPlas
17a26235e6 Make jax.scipy.optimize test compatible with upstream scipy 2023-11-06 14:10:24 -08:00
jax authors
7e372944f9 Fix the missing cache_misses metric when min compile time is set to zero.
Remove the code which checks if the min compile time is greater than zero. After this change, we can catch cache_misses when min compile time is zero.

Testing: revised unit test.
PiperOrigin-RevId: 579951415
2023-11-06 14:04:35 -08:00
jax authors
f54be7aa7c Merge pull request #11625 from jakevdp:deprecate-shuffle
PiperOrigin-RevId: 579933463
2023-11-06 13:04:42 -08:00
Jake VanderPlas
5f7335fb55 Deprecate jax.random.shuffle
This has been long deprecated, but this PR uses the standard deprecation
framework to make it easier to finalize.
2023-11-06 12:21:56 -08:00
jax authors
29a6262d11 Merge pull request #18380 from jakevdp:cross-2d
PiperOrigin-RevId: 579909035
2023-11-06 11:45:36 -08:00
Ben West
02f6fcb9da Add beta function 2023-11-05 15:37:38 -08:00
Tamás Danyluk
bfbf9e1c33 [XLA:GPU] Consider Triton for all non-pure GEMM fusions
This is a big step toward enabling xla_gpu_triton_gemm_any by default.

It shows about 1.05x geomean speedup on internal benchmarks (comparable to xla_gpu_triton_gemm_any=true).

PiperOrigin-RevId: 579524573
2023-11-04 16:05:19 -07:00
Jake VanderPlas
96d9f89415 [random] better errors for unsupported operations on prng keys 2023-11-03 19:23:18 -07:00
Brennan Shacklett
094579910f [Pallas]: Fix kernel grid dimensions that are too large in Y and Z 2023-11-03 17:17:42 -07:00
Jake VanderPlas
4f863e9148 jnp.cross: account for numpy 2.0 deprecation 2023-11-03 14:15:23 -07:00
jax authors
e227536fd6 In api_test.py, wait for the result in test_double_donation.
PiperOrigin-RevId: 579267104
2023-11-03 12:23:55 -07:00
Peter Hawkins
011d49c518 Add a test for double donation.
The underlying issue was fixed some time ago.

Fixes https://github.com/google/jax/issues/9635

PiperOrigin-RevId: 579170638
2023-11-03 07:03:13 -07:00
jax authors
db07f40233 Fall-back to original device/backend hashing if topology-desc is unavailable.
The original cache-key generation algorithm hashed devices and backend as
part of generating the key. The new algorithm relies on serialized
PjRtTopologyDescription instead. Not all backends support serialized
PjRtTopologyDescription. Fall back to the original device/backend hashing
if the needed backend does not support it.

Testing: unit testing + test workloads.
PiperOrigin-RevId: 579039803
2023-11-02 18:43:48 -07:00
Jieying Luo
c9db50cfd0 Enable python_callback_test for stream executor.
python_callback_test is supported for GPU stream executor. TPU stream executor was deprecated.

PiperOrigin-RevId: 578960299
2023-11-02 13:26:59 -07:00
George Necula
8feb413211 Add a lax.platform_dependent API for writing platform-dependent code.
In JAX the actual platform on which a computation is run is determined
very late, e.g., based on where the data is located. When using AOT
lowering or serialization, the computation may execute on a different
machine, or even on a platform that is not available at lowering time.
This means that it is not safe to write platform-dependent code using
Python conditionals, e.g., based on the current default JAX platform.
The proper way to do this is to introduce a primitive with
platform-specific lowering rules. This change introduces such a
primitive along with a user-facing API.

See more details in the docstring of lax.platform_dependent.
2023-11-02 14:31:38 +01:00
Reed Wanderman-Milne
d41078fb95 Properly pack and unpack int4 arrays on CPU in PJRT.
Transferring an array from host to device on CPU sometimes does a zero-copy implementation where no memory is actually moved. This is now never done with int4, since int4 arrays are stored in packed format on device and an unpacked format on host. Similarly, transferring an array from device to host on CPU used to always use a zero-copy implementation, but now it will unpack and copy for int4 arrays.

PiperOrigin-RevId: 578692796
2023-11-01 17:39:24 -07:00
Roy Frostig
16d082b002 [jex] replace extend.random.PRNGImpl with extend.random.define_prng_impl
Instead of exposing a constructor, only expose a function that returns an opaque
object representing the defined implementation. This result can still be passed
to `jax.random.key` and `wrap_key_data`.

PiperOrigin-RevId: 578349699
2023-10-31 17:21:54 -07:00
Yash Katariya
85af862efd [Try again] For nested pjit's cache the generation of StableHLO if it satifies the key. This should help in improving the lowering time.
Reverts 4a5c6f82009dee9c30ca4a85359a702d745ed035

PiperOrigin-RevId: 577974380
2023-10-30 15:28:43 -07:00
Sergei Lebedev
fd3a8b2cc6 Deprecated define_* and DEFINE_* methods on jax.config
These methods are internal to JAX. Yet, prior to this commit they were
effectively part of the public API, since users could (and some did!) invoke
them on `jax.config`.
2023-10-29 20:58:19 +00:00
Yunlong Liu
b99958db37 Places the remat decorator on top of the body function.
PiperOrigin-RevId: 577320028
2023-10-27 15:27:19 -07:00
jax authors
9ba305cced Invalidate in-memory caches on XLA-AutoFDO profile version change.
When the value in --jax_xla_profile_version changes, all tracing
and compilation caches should be invalidated since the XLA programs
need to be recompiled with the new XLA-AutoFDO profile.

Testing:
. New unit test.
. Test workload with instrumentation to repeatedly change
  the profile version. Before/after comparison.
PiperOrigin-RevId: 577280639
2023-10-27 12:52:57 -07:00
Skye Wanderman-Milne
58c86064f6 [PJRT:C] Implement PjRtCApiClient::GetTopologyDescription
PiperOrigin-RevId: 577249826
2023-10-27 11:03:04 -07:00
jax authors
11c4e2c820 [JAX] Add an option subset_by_index that allows computing a contiguous subset of eigenvalues from eigh.
PiperOrigin-RevId: 577222219
2023-10-27 09:29:33 -07:00
Yash Katariya
4d15375596 AOT sharding mismatch error shouldn't have GSPMDSharding in it.
PiperOrigin-RevId: 576668290
2023-10-25 15:48:01 -07:00
Peter Hawkins
47a76df7cc [IFRT] Fix incorrect type numbers for e4m3 and e5m2 types.
These types didn't match between xla::PrimitiveType and ifrt::DType.

Add a static_assert to enforce equality.

PiperOrigin-RevId: 576288042
2023-10-24 14:38:00 -07:00
Peter Hawkins
e7f1d29716 Relax some test tolerances for TPU.
PiperOrigin-RevId: 576192162
2023-10-24 10:45:40 -07:00
George Necula
9bc04393b2 Disable flaky python callback test.
PiperOrigin-RevId: 575893965
2023-10-23 12:24:05 -07:00
jax authors
1498865bbe Remove temporary patch for where broadcasting test
PiperOrigin-RevId: 575859446
2023-10-23 10:37:25 -07:00
jax authors
dde17cd5bc Merge pull request #18180 from carlosgmartin:fill_diagonal
PiperOrigin-RevId: 575317151
2023-10-20 14:20:14 -07:00
jax authors
7fdc06fa18 Merge pull request #17783 from gnecula:export_effects
PiperOrigin-RevId: 575310727
2023-10-20 13:55:22 -07:00
carlosgmartin
3cb504c583 Add jax.numpy.fill_diagonal. 2023-10-20 16:47:46 -04:00
George Necula
70f6a9e725 [export] Add support for exporting functions with effects
In presence of ordered effects JAX lowering produces a main
function that takes token
inputs and returns token outputs. Previously, when exporting
such a module, we would wrap the main function with a function
that does not use tokens on inputs and outputs. With this
change we actually leave the token inputs and outputs and
rely on consumers of the exported function to know how to
invoke a function with tokens.

Due to the fact that PJRT does not support passing tokens
as input and output to the top-level function, JAX native
lowering uses dummy bool[0] arrays in lieu of tokens for
the top-level function, and uses stablehlo tokens for the
inner functions. When we export a function for serialization
we want to use stablehlo tokens even at top-level, to enable
calling that function from a larger JAX computation later.

See more details about the calling convention in the
docstring for `export.export`.

We also fix and test multi-platform lowering in presence
of effects.

This introduces serialization version 9, but does not change the
default serialization version. This means that version 9 will not
be used except in tests that specifically override the
serialization version.
2023-10-20 22:27:27 +02:00