Kuangyuan Chen
b764aadbcf
Add pmap benchmarks using jax.Array
...
PiperOrigin-RevId: 473325120
2022-09-09 13:10:42 -07:00
Kuangyuan Chen
d17e516ea7
Add benchmarks for jax.Array
...
PiperOrigin-RevId: 471889808
2022-09-02 14:45:09 -07:00
Matthew Johnson
b3b4ffbc21
make slicing compilation a little faster
...
For the common special case of int or slice-of-int/None indexing, we can
generate a lax.slice rather than a lax.gather. That makes compilation a little
faster, and makes the generated jaxpr a bit more wieldy too, to process in
transformations and to read when pretty-printed.
2022-08-19 09:14:37 -07:00
Matthew Johnson
42dd7cac43
simplify slicing jaxprs a little
...
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-08-12 19:52:02 -07:00
Matthew Johnson
e3a92d52ba
prepare to switch to new remat
...
This commit involves a few things, which are all united in being about landing
the new remat (aka new checkpoint) implementation:
* add benchmarks for new remat eager performance, and some caching to make those
benchmarks fast
* warn when the old-remat-exclusive `concrete` feature is used, with an
actionable message pointing to the new recommended approach involving static_argnums
* add the static_argnums parameter to both new and old remt
* update docstrings (and de-duplicate them to)
* add new tests, especially around caching and errors/warnings
2022-08-04 12:25:03 -07:00
Yash Katariya
2109c6ec8c
Make is_compatible_aval
an optional method which sharding interfaces can implement to raise a more meaningful error. Otherwise lower to opsharding and catch the error if it fails.
...
PiperOrigin-RevId: 464147877
2022-07-29 13:37:15 -07:00
Yash Katariya
47623264db
Export HloSharding
via pybind which is a C++ wrapper around OpSharding proto.
...
PiperOrigin-RevId: 463992136
2022-07-28 21:01:15 -07:00
Matthew Johnson
148173630f
add an optional fastpath for api_util.shaped_abstractify
...
also add a benchmark for it, 8.7ms -> 0.2ms on my machine
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2022-07-27 15:14:37 -07:00
Jake VanderPlas
a10f0377db
Avoid top-level aliases of jax.tree_util.*
2022-07-07 11:41:02 -07:00
Qiao Zhang
5d7f639769
Add small and big matmul to api_benchmarks.
...
name cpu/op
jit_small_matmul 2.96µs ± 2%
jit_big_matmul 22.1µs ±21%
name time/op
jit_small_matmul 2.96µs ± 2%
jit_big_matmul 22.7µs ±21%
PiperOrigin-RevId: 435453853
2022-03-17 14:51:17 -07:00
Tianjian Lu
6a0bfdd637
[JAX] Requires indices to be sorted and of int32 in _sparse_bcoo_matvec
test.
...
PiperOrigin-RevId: 417695937
2021-12-21 14:56:52 -08:00
Jake VanderPlas
b8296bb06c
benchmarks: add compilation of sparse operations
...
PiperOrigin-RevId: 405752775
2021-10-26 15:41:21 -07:00
Jake VanderPlas
c62452f2d2
benchmarks: add JIT versions of sparse.BCOO benchmarks
...
PiperOrigin-RevId: 405696495
2021-10-26 11:39:01 -07:00
Jake VanderPlas
5a3733788c
benchmarks: add sparse.BCOO todense/fromdense/matvec benchmarks
...
PiperOrigin-RevId: 405519824
2021-10-25 16:44:36 -07:00
Yash Katariya
bfbdfa87e7
Add a warmup loop to pmap_simple_8_devices_100_args benchmark so as to not measure the compile time.
...
PiperOrigin-RevId: 401402336
2021-10-06 19:51:35 -07:00
Peter Hawkins
6a1b626564
Remove jax.api.
...
Functions exported as jax.api were aliases for names in jax.*. Use the jax.* names instead.
2021-09-16 16:29:06 -04:00
Jean-Baptiste Lespiau
6cb8737c1a
Add a benchmark with many arguments.
...
PiperOrigin-RevId: 393216026
2021-08-26 15:06:09 -07:00
Jean-Baptiste Lespiau
afdd195e42
Internal only.
...
PiperOrigin-RevId: 388562206
2021-08-03 15:49:17 -07:00
Qiao Zhang
850bd66242
[JAX] Prune unused inputs in jit.
...
- Python part based on: https://github.com/google/jax/pull/6567
- Added cpp_jit path to handle pruned args
PiperOrigin-RevId: 371743277
2021-05-03 11:41:29 -07:00
Tom Hennigan
9d56552517
Remove special casing on npy_value when indexing sharded arrays.
...
Before:
```
GPU
sda_index_2 2912972 ns 2778716 ns 256
TPU
sda_index_1 769968 ns 751700 ns 921
sda_index_2 1510841 ns 1489716 ns 465
sda_index_8 6102259 ns 6027655 ns 117
```
After:
```
GPU
sda_index_2 28095 ns 27983 ns 25463
TPU
sda_index_1 10302 ns 10279 ns 67884
sda_index_2 20010 ns 19947 ns 34628
sda_index_8 78492 ns 78306 ns 8934
```
PiperOrigin-RevId: 368380864
2021-04-14 01:17:10 -07:00
Matthew Johnson
9802f3378e
add simple single-primitive eager benchmarks
2021-03-18 21:46:46 -07:00
Peter Hawkins
cdd36b1113
Improve API benchmarks.
...
Add benchmarks for different dispatch arg arities.
Add more blocking before and after benchmark loops that don't otherwise block.
2021-03-03 20:50:45 -05:00
Tom Hennigan
7b94c44af9
Sort imports.
...
PiperOrigin-RevId: 359532656
2021-02-25 08:52:09 -08:00
Jean-Baptiste Lespiau
e95d5701e3
Add benchmarks for specifically the dispatch time. ( #4128 )
...
The goal is to distinguish the time it takes for `jitted_f` to return, and the time it takes to return and wait for the result.
We also add one to distinguish the time it takes to call the function with the argument transfer or without it.
e.g.
name time/op
jit_trivial_dispatch 28.9µs ± 2%
jit_trivial 31.5µs ± 5%
jit_simple_dispatch 60.7µs ± 4%
jit_simple 129µs ±24%
jit_simple_many_args_disptch 390µs ±19%
jit_simple_many_args 388µs ±16%
jit_dispatch_without_transfer 379µs ± 6%
jit_dispatch_with_transfer 450µs ± 5%
2020-08-27 17:02:13 +03:00
Chris Jones
c1aeb8b3fe
Add simple JAX API microbenchmarks. ( #3679 )
2020-07-09 10:02:23 -07:00