131 Commits

Author SHA1 Message Date
Daniel Suo
c76b8adcf5 Add Jax tracing micro benchmarks.
Add a first benchmark for tracing/lowering pallas splash attention.

Sample results below taken on a GCP n2d-standard-128 instance with 512GB Ram and 128 vCPU AMD EPYC Milan.

---------------------------------------------------------------------------------
Benchmark                                       Time             CPU   Iterations
---------------------------------------------------------------------------------
test_pallas_mqa_splash_attention_trace       39.8 ms         39.8 ms           19
test_pallas_mqa_splash_attention_lower       42.1 ms         41.9 ms           18

PiperOrigin-RevId: 742259409
2025-04-09 18:05:14 +00:00
Hyeontaek Lim
73b8f6aee2 [JAX] Clean up make_array_from_callback_* API benchmarks and add a partially replicated sharding variant
To prepare for the upcoming `BatchedDevicePut` implementation changes, this
change makes `make_array_from_callback_*` benchmark code to be more
homogeneous. Also it adds a variant that uses a partially replicated sharding.

PiperOrigin-RevId: 736665856
2025-03-13 15:50:46 -07:00
Hyeontaek Lim
178278863d [JAX] Fix api_benchmark broken by https://github.com/jax-ml/jax/pull/26569
`pjit_check_aval_sharding` expects `names: Sequence[str]`.

PiperOrigin-RevId: 734614264
2025-03-07 10:49:53 -08:00
Jake VanderPlas
5dc37d3f70 Remove internal uses of api_util.shaped_abstractify 2024-12-19 07:06:36 -08:00
Jake VanderPlas
2c722d9b13 Cleanup: toward merging core.concrete_aval & xla.abstractify 2024-12-17 09:27:00 -08:00
George Necula
0831e2e340 [shape_poly] Adding shape polymorphism support for the state primitives. 2024-11-21 06:17:01 -08:00
Peter Hawkins
26632fd344 Replace disable_backends with enable_backends on jax_multiplatform_test.
Most users of disable_backends were actually using it to enable only a single backend. So things are simpler if we negate the sense of the option to say that. Change disable_configs to enable_configs, with a default `None` value meaning "everything is enabled".

We change the relationship between enable_backends, disable_configs, enable_configs to be the following:
* `enable_backends` selects a set of initial test configurations to enable, based off backend only.
* `disable_configs` then prunes that set of test configurations, removing elements from the set.
* `enable_configs` then adds additional configurations to the set.

Fix code in jax/experimental/mosaic/gpu/examples not to depend on a Google-internal GPU support target.

PiperOrigin-RevId: 679563155
2024-09-27 06:15:31 -07:00
Peter Hawkins
a43c7f2ace Enable more H100 tests in CI.
Rename "gpu" config CI tag to "gpu_v100".

PiperOrigin-RevId: 678695003
2024-09-25 07:37:48 -07:00
Peter Hawkins
70f91db853 Set PYTHONWARNINGS=error in bazel tests.
The goal of this change is to catch PRs that introduce new warnings sooner.

To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable.

Add code to suppress some new warnings uncovered in CI.

PiperOrigin-RevId: 678352286
2024-09-24 12:30:11 -07:00
jax authors
eaefabee85 Fixes to api_benchmark.py. Testcases always fail without these fixes.
PiperOrigin-RevId: 668299061
2024-08-27 23:37:49 -07:00
Peter Hawkins
6d1f51e63d Clean up BUILD files.
PiperOrigin-RevId: 667604964
2024-08-26 09:11:17 -07:00
Roy Frostig
371935cc10 update README and several docs to typed RNG keys 2024-08-11 08:09:47 -07:00
Mark Sandler
4a91396e91 Adds make_array_from_callback_sharded benchmark
PiperOrigin-RevId: 646138175
2024-06-24 10:17:40 -07:00
Peter Hawkins
9e30079dba [JAX] Add caching to pjit._infer_params.
When tracing inner jits, we currently redo a lot of tracing work, which we can cache. Just as we have a C++ fast path for top-level jit calls, we can reuse the same logic for inner jits. We use part of the C++ fast path code to compute the signature of the arguments and split apart the dynamic arguments to compute a cache key. If we have seen the cache key before, we can avoid doing most of the work of _infer_params.

In passing, fix a bug where DynamicJaxprTracer's shaped_abstractify rule sometimes produces concrete avals.

```
name           old cpu/op   new cpu/op   delta
jit_add_chain  59.1ms ±14%  49.4ms ±10%  -16.32%  (p=0.008 n=5+5)

name           old time/op          new time/op          delta
jit_add_chain  60.3ms ±14%          50.7ms ±11%  -15.99%          (p=0.008 n=5+5)
```

PiperOrigin-RevId: 645491650
2024-06-21 13:53:04 -07:00
jax authors
3fe7377719 Merge pull request #21763 from gnecula:export_api
PiperOrigin-RevId: 641959833
2024-06-10 11:05:34 -07:00
George Necula
b33aca6b08 [export] Create the jax.export module APIs.
The functionality comes from the jax.experimental.export
module, which will be deprecated.

The following APIs are introduced:

```
  from jax import export
  def f(...): ...
  ex: export.Exported = export.export(jax.jit(f))(*args, **kwargs)

  blob: bytearray = ex.serialize()
  rehydrated: export.Export = export.deserialize(blob)

  def caller(...):
     ... rehydrated.call(*args, **kwargs)
```

Module documentation will follow shortly.
There are no changes for now in the jax.experimental.export
APIs.

Most of the changes in this PR are in tests due to some differences
in the new jax.export APIs compared to jax.experimental.export:

  * Instead of `jax.experimental.export.call(exp)` we now write
    `exp.call`
  * The `jax.experimental.export.export` allowed the function
    argument to be any Python callable and it would wrap it with
    a `jax.jit`. This is not supported anymore by export, and instead
    the user must use `jax.jit`.
2024-06-10 19:31:51 +02:00
Junwhan Ahn
6617a0d1ed Expand device_put benchmarks to run with different numbers of arrays and input types
For the upcoming batching changes for `device_put`, it is useful to benchmark `device_put` with varying numbers of arrays.

PiperOrigin-RevId: 641716268
2024-06-09 13:01:51 -07:00
Christos Perivolaropoulos
24e4bf2265 [Mosaic GPU] Add f32 benchmarks for matmul.
PiperOrigin-RevId: 640826101
2024-06-06 02:31:40 -07:00
Christos Perivolaropoulos
8eaea2b13d [Mosaic GPU] Add a simple benchmark.
PiperOrigin-RevId: 639023867
2024-05-31 07:08:28 -07:00
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Jake VanderPlas
e59a0506fe Deprecate jax.tree_map in favor of jax.tree.map 2024-02-22 11:35:39 -08:00
Yash Katariya
0b542ff585 Add a benchmark measuring device_put's speed for a 4GB input array
```
---------------------------------------------------------
Benchmark               Time             CPU   Iterations
---------------------------------------------------------
device_put_big        419 ms        0.363 ms           10
```

PiperOrigin-RevId: 607512568
2024-02-15 17:47:10 -08:00
George Necula
1e7642279b [shape_poly] Add another micro-benchmark
A very common operation is to compute linear combinations of expressions.

PiperOrigin-RevId: 605644781
2024-02-09 09:02:59 -08:00
George Necula
313caba4cb [shape_poly] Add micro-benchmarks for symbolic shape manipulation.
In some of the cases when we use many symbolic expressions for shapes, the operations with
symbolic expressions are becoming somewhat costly. These benchmarks help with picking
good internal representations.

PiperOrigin-RevId: 604289958
2024-02-05 05:38:25 -08:00
Matthew Johnson
e31018b9c5 in partial_eval_custom rule for pjit, cache ClosedJaxpr creation
Anywhere we call the ClosedJaxpr constructor, we had better be under a cache.
We should audit the code...

Never trust comments, especially when blame says mattjj wrote them

Co-authored-by: Yash Katariya <yashkatariya@google.com>
2023-12-21 15:35:45 -08:00
Peter Hawkins
cb182b8b22 Use a Jacobi SVD solver for unbatched SVDs up to 1024x1024 on NVIDIA GPUs.
The unbatched Jacobi solver is faster for small-moderate matrices, and the unbatched kernel doesn't have size restrictions.

Timings on T4 GPU:

Before:

------------------------------------------------------------
Benchmark                  Time             CPU   Iterations
------------------------------------------------------------
svd/m:1/n:1           263587 ns       242274 ns         2780
svd/m:2/n:1           335561 ns       298238 ns         2303
svd/m:5/n:1           337784 ns       299841 ns         2304
svd/m:10/n:1          339184 ns       300703 ns         2311
svd/m:100/n:1         359826 ns       320088 ns         2159
svd/m:500/n:1         376124 ns       338660 ns         2076
svd/m:800/n:1         375779 ns       335590 ns         2060
svd/m:1000/n:1        419171 ns       341487 ns         2072
svd/m:1/n:2           307564 ns       270663 ns         2544
svd/m:2/n:2           320928 ns       283601 ns         2487
svd/m:5/n:2           377373 ns       344228 ns         2035
svd/m:10/n:2          380557 ns       349412 ns         1953
svd/m:100/n:2         435465 ns       403496 ns         1722
svd/m:500/n:2         444610 ns       410913 ns         1680
svd/m:800/n:2         454493 ns       416495 ns         1665
svd/m:1000/n:2        492110 ns       420539 ns         1665
svd/m:1/n:5           307316 ns       275833 ns         2531
svd/m:2/n:5           374318 ns       341432 ns         2086
svd/m:5/n:5           512928 ns       470293 ns         1361
svd/m:10/n:5          589330 ns       537070 ns         1353
svd/m:100/n:5         620164 ns       580166 ns         1193
svd/m:500/n:5         636424 ns       593692 ns         1180
svd/m:800/n:5         635545 ns       595016 ns         1181
svd/m:1000/n:5        672443 ns       597387 ns         1115
svd/m:1/n:10          310013 ns       273998 ns         2520
svd/m:2/n:10          370451 ns       334489 ns         2105
svd/m:5/n:10          560037 ns       522223 ns         1274
svd/m:10/n:10         572868 ns       535388 ns         1304
svd/m:100/n:10        959802 ns       918258 ns          765
svd/m:500/n:10        955958 ns       909778 ns          758
svd/m:800/n:10        924104 ns       879512 ns          777
svd/m:1000/n:10       950140 ns       883493 ns          775
svd/m:1/n:100         351237 ns       315554 ns         2198
svd/m:2/n:100         426883 ns       390089 ns         1792
svd/m:5/n:100         601557 ns       564493 ns         1255
svd/m:10/n:100        920819 ns       880011 ns          787
svd/m:100/n:100      7902281 ns      7229220 ns           95
svd/m:500/n:100      9720727 ns      9040679 ns           79
svd/m:800/n:100      9856378 ns      8998050 ns           79
svd/m:1000/n:100     9721017 ns      9086414 ns           79
svd/m:1/n:500         371171 ns       334217 ns         2117
svd/m:2/n:500         449165 ns       411499 ns         1700
svd/m:5/n:500         620354 ns       581866 ns         1185
svd/m:10/n:500        892375 ns       847239 ns          833
svd/m:100/n:500      9564810 ns      8867540 ns           79
svd/m:500/n:500    111924035 ns    104078023 ns            7
svd/m:800/n:500    147777319 ns    142730412 ns            5
svd/m:1000/n:500   154205084 ns    149740209 ns            5
svd/m:1/n:800         372122 ns       334212 ns         2119
svd/m:2/n:800         456672 ns       419260 ns         1680
svd/m:5/n:800         691208 ns       626003 ns         1190
svd/m:10/n:800       1017694 ns       941480 ns          730
svd/m:100/n:800      9892683 ns      9091043 ns           76
svd/m:500/n:800    144134235 ns    139129722 ns            5
svd/m:800/n:800    342790246 ns    333299774 ns            2
svd/m:1000/n:800   432820082 ns    427978978 ns            2
svd/m:1/n:1000        372785 ns       335745 ns         1805
svd/m:2/n:1000        451946 ns       413341 ns         1668
svd/m:5/n:1000        618475 ns       577213 ns         1169
svd/m:10/n:1000       907729 ns       863335 ns          808
svd/m:100/n:1000     9868543 ns      9116870 ns           76
svd/m:500/n:1000   156777811 ns    152042065 ns            5
svd/m:800/n:1000   429704070 ns    424677592 ns            2
svd/m:1000/n:1000  654864311 ns    642693162 ns            1

After:
------------------------------------------------------------
Benchmark                  Time             CPU   Iterations
------------------------------------------------------------
svd/m:1/n:1           265980 ns       245433 ns         2791
svd/m:2/n:1           340203 ns       302783 ns         2288
svd/m:5/n:1           337807 ns       301916 ns         2286
svd/m:10/n:1          338064 ns       302441 ns         2297
svd/m:100/n:1         335444 ns       298440 ns         2327
svd/m:500/n:1         338025 ns       302096 ns         2272
svd/m:800/n:1         328382 ns       291740 ns         2252
svd/m:1000/n:1        397494 ns       310905 ns         2239
svd/m:1/n:2           310464 ns       274507 ns         2535
svd/m:2/n:2           319999 ns       284247 ns         2515
svd/m:5/n:2           373435 ns       335919 ns         2069
svd/m:10/n:2          376327 ns       339327 ns         2056
svd/m:100/n:2         385061 ns       349258 ns         2003
svd/m:500/n:2         392352 ns       355735 ns         1932
svd/m:800/n:2         410736 ns       370677 ns         1881
svd/m:1000/n:2        494326 ns       405603 ns         1721
svd/m:1/n:5           316735 ns       277292 ns         2538
svd/m:2/n:5           383748 ns       342218 ns         2077
svd/m:5/n:5           494204 ns       454309 ns         1476
svd/m:10/n:5          547017 ns       508184 ns         1371
svd/m:100/n:5         514537 ns       476761 ns         1460
svd/m:500/n:5         544656 ns       504877 ns         1381
svd/m:800/n:5         642590 ns       599314 ns         1159
svd/m:1000/n:5        706166 ns       621209 ns         1106
svd/m:1/n:10          310825 ns       274374 ns         2511
svd/m:2/n:10          381316 ns       344202 ns         2094
svd/m:5/n:10          565469 ns       526759 ns         1266
svd/m:10/n:10         576111 ns       537286 ns         1299
svd/m:100/n:10        653250 ns       613392 ns         1137
svd/m:500/n:10        690532 ns       645828 ns         1080
svd/m:800/n:10        763924 ns       723677 ns          959
svd/m:1000/n:10       940342 ns       855517 ns          818
svd/m:1/n:100         306134 ns       271533 ns         2526
svd/m:2/n:100         374680 ns       339298 ns         2071
svd/m:5/n:100         576926 ns       539062 ns         1228
svd/m:10/n:100        656806 ns       615171 ns         1123
svd/m:100/n:100      3295164 ns      3138621 ns          223
svd/m:500/n:100      4269347 ns      4166000 ns          168
svd/m:800/n:100      4656541 ns      4522247 ns          154
svd/m:1000/n:100     6479223 ns      6354578 ns          112
svd/m:1/n:500         329966 ns       289083 ns         2440
svd/m:2/n:500         407535 ns       366794 ns         1947
svd/m:5/n:500         567367 ns       522809 ns         1336
svd/m:10/n:500        712307 ns       657608 ns         1065
svd/m:100/n:500      4262986 ns      4169907 ns          167
svd/m:500/n:500     28824720 ns     28650258 ns           25
svd/m:800/n:500     29330139 ns     28677269 ns           25
svd/m:1000/n:500    30848037 ns     30089216 ns           23
svd/m:1/n:800         328620 ns       289181 ns         2329
svd/m:2/n:800         419052 ns       379483 ns         1876
svd/m:5/n:800         587366 ns       546979 ns         1269
svd/m:10/n:800        830762 ns       787923 ns          893
svd/m:100/n:800      4763633 ns      4595738 ns          152
svd/m:500/n:800     30447861 ns     29949714 ns           24
svd/m:800/n:800     94188958 ns     93488372 ns            8
svd/m:1000/n:800    94701529 ns     93394677 ns            7
svd/m:1/n:1000        351102 ns       313099 ns         2218
svd/m:2/n:1000        446543 ns       407807 ns         1708
svd/m:5/n:1000        661152 ns       616174 ns         1129
svd/m:10/n:1000       915743 ns       873397 ns          802
svd/m:100/n:1000     6434730 ns      6282779 ns          113
svd/m:500/n:1000    30244321 ns     29684290 ns           24
svd/m:800/n:1000    92727423 ns     91477078 ns            8
svd/m:1000/n:1000  169500709 ns    168358420 ns            4
PiperOrigin-RevId: 582041508
2023-11-13 12:04:13 -08:00
Yash Katariya
fd09b35645 Optimize make_array_from_callback for fully replicated shardings by going via batched_device_put
Before:

```
name                                                      cpu/op
bench_make_array_from_callback_fully_replicated_sharding  467µs ± 3%

name                                                      time/op
bench_make_array_from_callback_fully_replicated_sharding  467µs ± 3%
```

After:

```
name                                                      cpu/op
bench_make_array_from_callback_fully_replicated_sharding  28.1µs ± 2%

name                                                      time/op
bench_make_array_from_callback_fully_replicated_sharding  28.1µs ± 2%
```

PiperOrigin-RevId: 572429822
2023-10-10 19:02:04 -07:00
Patrick Kidger
9d73441ff1 Added serial_dot_products benchmark 2023-09-21 15:25:52 -07:00
Roy Frostig
6abefa1977 fast dispatch for functions over typed PRNG key arrays
Before this change, JAX could dispatch compiled functions over new-style (typed)
RNG key arrays, but it would always do so off of the fast (C++-based) dispatch
path. In other words, switching from old-style `uint32` RNG keys to new-style
keys would regress dispatch times. With this change, dispatch happens on the
fast path again and performance regressions ought to be minimal.

We currently maintain only one pytree registry, for all registered pytree node
types. We want RNG key arrays to also be treated as pytree leaves everywhere
*except* during dispatch. In other words: we want operations on (typed) RNG key
arrays to appear in Jaxpr, but we want to unravel those arrays into their
underlying `uint32` arrays only during dispatch.

To do this, we add a new internal pytree registry that dispatch respects
uniquely. This registry includes all items in the default registry, but also the
RNG key array type.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 565077758
2023-09-13 09:43:58 -07:00
Jake VanderPlas
368d3433a6 Add random benchmarks
The purpose of this is to measure the difference in dispatch seed between raw keys and new-style typed keys. The latter does not yet hit the C++ fast path, and so we expect it to incur a small additional overhead at dispatch time. Part of #9263.

PiperOrigin-RevId: 559782442
2023-08-24 09:55:07 -07:00
jax authors
f498442daa [jax][benchmark] Added clearing caches for benchmarking compilation time in sparse JAX benchmarks
PiperOrigin-RevId: 553179605
2023-08-02 10:07:54 -07:00
Peter Hawkins
76cda0ae07 Update flags to use the ABSL typed flag API.
Change flags to use the newer definition style where the flag is read via a typed FlagHolder object returned by the DEFINE_... function. The advantage of doing this is that `flag.value` has a type known to the type checker, rather than reading it as an attr out of a gigantic config dictionary.

For jax.config flags, define a typed FlagHolder object that is returned when defining a flag, matching the ABSL API.

Move a number of flags into the file that consumes them. There's no reason we're defining every flag in `config.py`.

This PR does not change the similar "state" objects in `jax.config`. Changing those is for a future PR.

PiperOrigin-RevId: 551604974
2023-07-27 12:15:58 -07:00
Yash Katariya
a6254c75e0 Improve the shape incompatible error message by adding the argument/result name path to it.
PiperOrigin-RevId: 529605855
2023-05-04 21:50:04 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Jake VanderPlas
5521423d92 Change np.prod->math.prod
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
2023-04-13 11:48:11 -07:00
Peter Hawkins
74384e6a87 Add a C++ safe_zip implementation.
Benchmark results on my workstation:
```
name                                 old cpu/op   new cpu/op   delta
safe_zip/arg_lengths:0/num_args:1    1.22µs ± 1%  0.28µs ± 8%  -77.33%  (p=0.008 n=5+5)
safe_zip/arg_lengths:1/num_args:1    1.28µs ± 1%  0.34µs ± 6%  -73.18%  (p=0.008 n=5+5)
safe_zip/arg_lengths:2/num_args:1    1.28µs ± 1%  0.38µs ± 5%  -70.26%  (p=0.008 n=5+5)
safe_zip/arg_lengths:5/num_args:1    1.38µs ± 1%  0.51µs ± 3%  -63.26%  (p=0.008 n=5+5)
safe_zip/arg_lengths:10/num_args:1   1.61µs ± 1%  0.69µs ± 3%  -56.93%  (p=0.008 n=5+5)
safe_zip/arg_lengths:100/num_args:1  5.39µs ± 1%  3.83µs ± 2%  -29.03%  (p=0.008 n=5+5)
safe_zip/arg_lengths:0/num_args:2    1.46µs ± 1%  0.32µs ± 4%  -78.30%  (p=0.008 n=5+5)
safe_zip/arg_lengths:1/num_args:2    1.52µs ± 1%  0.39µs ± 4%  -74.20%  (p=0.008 n=5+5)
safe_zip/arg_lengths:2/num_args:2    1.53µs ± 1%  0.44µs ± 4%  -71.38%  (p=0.008 n=5+5)
safe_zip/arg_lengths:5/num_args:2    1.66µs ± 2%  0.60µs ± 3%  -63.96%  (p=0.008 n=5+5)
safe_zip/arg_lengths:10/num_args:2   1.90µs ± 1%  0.82µs ± 3%  -56.66%  (p=0.008 n=5+5)
safe_zip/arg_lengths:100/num_args:2  6.51µs ± 1%  4.80µs ± 0%  -26.23%  (p=0.016 n=5+4)
safe_zip/arg_lengths:0/num_args:3    1.62µs ± 1%  0.36µs ± 4%  -77.95%  (p=0.008 n=5+5)
safe_zip/arg_lengths:1/num_args:3    1.68µs ± 1%  0.44µs ± 3%  -73.75%  (p=0.008 n=5+5)
safe_zip/arg_lengths:2/num_args:3    1.69µs ± 1%  0.50µs ± 3%  -70.48%  (p=0.008 n=5+5)
safe_zip/arg_lengths:5/num_args:3    1.83µs ± 1%  0.68µs ± 2%  -62.73%  (p=0.008 n=5+5)
safe_zip/arg_lengths:10/num_args:3   2.12µs ± 1%  0.96µs ± 1%  -54.71%  (p=0.008 n=5+5)
safe_zip/arg_lengths:100/num_args:3  7.34µs ± 2%  5.89µs ± 1%  -19.74%  (p=0.008 n=5+5)
```

In addition, improve the length mismatch error for safe_map and define __module__ on both functions.

PiperOrigin-RevId: 523475834
2023-04-11 12:43:04 -07:00
Peter Hawkins
0dbd467cea Add a C++ implementation of safe map.
Before (argument names reversed, oops, fixed in code):

```
name                                 time/op
safe_map/num_args:0/arg_lengths:1    1.43µs ± 1%
safe_map/num_args:1/arg_lengths:1    1.61µs ± 1%
safe_map/num_args:2/arg_lengths:1    1.72µs ± 0%
safe_map/num_args:5/arg_lengths:1    2.14µs ± 1%
safe_map/num_args:10/arg_lengths:1   2.87µs ± 1%
safe_map/num_args:100/arg_lengths:1  15.6µs ± 1%
safe_map/num_args:0/arg_lengths:2    1.65µs ± 0%
safe_map/num_args:1/arg_lengths:2    1.83µs ± 1%
safe_map/num_args:2/arg_lengths:2    1.97µs ± 1%
safe_map/num_args:5/arg_lengths:2    2.41µs ± 1%
safe_map/num_args:10/arg_lengths:2   3.22µs ± 2%
safe_map/num_args:100/arg_lengths:2  17.0µs ± 2%
safe_map/num_args:0/arg_lengths:3    1.83µs ± 1%
safe_map/num_args:1/arg_lengths:3    2.02µs ± 1%
safe_map/num_args:2/arg_lengths:3    2.16µs ± 1%
safe_map/num_args:5/arg_lengths:3    2.63µs ± 1%
safe_map/num_args:10/arg_lengths:3   3.48µs ± 1%
safe_map/num_args:100/arg_lengths:3  18.1µs ± 1%
```

After:
```
name                                 time/op
safe_map/num_args:0/arg_lengths:1     409ns ± 1%
safe_map/num_args:1/arg_lengths:1     602ns ± 5%
safe_map/num_args:2/arg_lengths:1     777ns ± 4%
safe_map/num_args:5/arg_lengths:1    1.21µs ± 3%
safe_map/num_args:10/arg_lengths:1   1.93µs ± 2%
safe_map/num_args:100/arg_lengths:1  14.7µs ± 0%
safe_map/num_args:0/arg_lengths:2     451ns ± 1%
safe_map/num_args:1/arg_lengths:2     652ns ± 0%
safe_map/num_args:2/arg_lengths:2     850ns ± 4%
safe_map/num_args:5/arg_lengths:2    1.32µs ± 3%
safe_map/num_args:10/arg_lengths:2   2.11µs ± 2%
safe_map/num_args:100/arg_lengths:2  16.0µs ± 1%
safe_map/num_args:0/arg_lengths:3     496ns ± 1%
safe_map/num_args:1/arg_lengths:3     718ns ± 5%
safe_map/num_args:2/arg_lengths:3     919ns ± 4%
safe_map/num_args:5/arg_lengths:3    1.43µs ± 2%
safe_map/num_args:10/arg_lengths:3   2.30µs ± 2%
safe_map/num_args:100/arg_lengths:3  17.3µs ± 1%
```
PiperOrigin-RevId: 523263207
2023-04-10 18:09:56 -07:00
Yash Katariya
694e43a44a Remove experimental_cpp_jit since that flag is unused and also remove experimental_cpp_pjit.
For dynamic shapes experimentation and normal debugging, `python_pjit` still exists so that problem doesn't exist which makes us free to remove these 2 flags.

I am leaving pmap's flag alone for now.

PiperOrigin-RevId: 522602754
2023-04-07 08:29:20 -07:00
Peter Hawkins
452f3c55e3 Rename jax._src.sharding_utils to jax._src.op_shardings.
Move some more op_sharding related helpers to that module.

PiperOrigin-RevId: 522343010
2023-04-06 08:32:46 -07:00
Yash Katariya
cf8c2b8450 Delete benchmark and pmap_benchmark files as they are legacy and replaced with api_benchmark.py
PiperOrigin-RevId: 519742866
2023-03-27 09:22:57 -07:00
Yash Katariya
1faa7a8edd Add benchmarks for accessing index and replica id in addressable_shards
PiperOrigin-RevId: 517974091
2023-03-20 08:22:34 -07:00
Parker Schuh
48702171bf Add benchmarks for np.array, device_put, and _arrays.
PiperOrigin-RevId: 516692492
2023-03-14 19:06:06 -07:00
Yash Katariya
233911c001 [Fix forward] Rollback the device_put_sharded and device_put_replicated change of using batched_device_put
PiperOrigin-RevId: 516244071
2023-03-13 10:07:44 -07:00
Peter Hawkins
1925aa1109 Split Sharding subclasses out of _src/sharding.py into _src/sharding_impls.py
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.

PiperOrigin-RevId: 516223009
2023-03-13 08:50:18 -07:00
Emilio Cota
6f1d82916c math_benchmark: add --set_env flag
PiperOrigin-RevId: 515417422
2023-03-09 13:04:12 -08:00
Emilio Cota
845d68b39e math_benchmark: add dot op
PiperOrigin-RevId: 515408666
2023-03-09 12:24:47 -08:00
Peter Hawkins
8fb1fd318d Replace jax._src.util.prod with math.prod.
math.prod() was added in Python 3.8, so we can assume it is always present.

PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Lena Martens
4f48f94649 Update api_benchmark to not use any deprecated APIs.
PiperOrigin-RevId: 512941633
2023-02-28 08:33:26 -08:00
Yash Katariya
418c2f9d2a Rename in_axis_resources and out_axis_resources with in_shardings and out_shardings. This is just a simple name replacement. It does not change any of the current pjit semantics and doesn't break any code.
This is a safe and trivial name replacement. It does not change any of the semantics. You can still pass in PatitionSpecs to in_shardings and out_shardings.

PiperOrigin-RevId: 510671300
2023-02-18 10:00:36 -08:00
Yash Katariya
d21ff0371f Remove gda_benchmark file as GDA is deprecated.
PiperOrigin-RevId: 510469600
2023-02-17 10:46:25 -08:00