9 Commits

Author SHA1 Message Date
Jean-Baptiste Lespiau
6cb8737c1a Add a benchmark with many arguments.
PiperOrigin-RevId: 393216026
2021-08-26 15:06:09 -07:00
Jean-Baptiste Lespiau
afdd195e42 Internal only.
PiperOrigin-RevId: 388562206
2021-08-03 15:49:17 -07:00
Qiao Zhang
850bd66242 [JAX] Prune unused inputs in jit.
- Python part based on: https://github.com/google/jax/pull/6567
- Added cpp_jit path to handle pruned args

PiperOrigin-RevId: 371743277
2021-05-03 11:41:29 -07:00
Tom Hennigan
9d56552517 Remove special casing on npy_value when indexing sharded arrays.
Before:

```
GPU
sda_index_2                           2912972 ns      2778716 ns          256

TPU
sda_index_1                            769968 ns       751700 ns          921
sda_index_2                           1510841 ns      1489716 ns          465
sda_index_8                           6102259 ns      6027655 ns          117
```

After:

```
GPU
sda_index_2                             28095 ns        27983 ns        25463

TPU
sda_index_1                             10302 ns        10279 ns        67884
sda_index_2                             20010 ns        19947 ns        34628
sda_index_8                             78492 ns        78306 ns         8934
```

PiperOrigin-RevId: 368380864
2021-04-14 01:17:10 -07:00
Matthew Johnson
9802f3378e add simple single-primitive eager benchmarks 2021-03-18 21:46:46 -07:00
Peter Hawkins
cdd36b1113 Improve API benchmarks.
Add benchmarks for different dispatch arg arities.

Add more blocking before and after benchmark loops that don't otherwise block.
2021-03-03 20:50:45 -05:00
Tom Hennigan
7b94c44af9 Sort imports.
PiperOrigin-RevId: 359532656
2021-02-25 08:52:09 -08:00
Jean-Baptiste Lespiau
e95d5701e3
Add benchmarks for specifically the dispatch time. (#4128)
The goal is to distinguish the time it takes for `jitted_f` to return, and the time it takes to return and wait for the result.
We also add one to distinguish the time it takes to call the function with the argument transfer or without it.

e.g.

name                                   time/op
jit_trivial_dispatch                   28.9µs ± 2%
jit_trivial                            31.5µs ± 5%
jit_simple_dispatch                    60.7µs ± 4%
jit_simple                              129µs ±24%
jit_simple_many_args_disptch            390µs ±19%
jit_simple_many_args                    388µs ±16%
jit_dispatch_without_transfer           379µs ± 6%
jit_dispatch_with_transfer              450µs ± 5%
2020-08-27 17:02:13 +03:00
Chris Jones
c1aeb8b3fe
Add simple JAX API microbenchmarks. (#3679) 2020-07-09 10:02:23 -07:00