76 Commits

Author SHA1 Message Date
Hyeontaek Lim
73b8f6aee2 [JAX] Clean up make_array_from_callback_* API benchmarks and add a partially replicated sharding variant
To prepare for the upcoming `BatchedDevicePut` implementation changes, this
change makes `make_array_from_callback_*` benchmark code to be more
homogeneous. Also it adds a variant that uses a partially replicated sharding.

PiperOrigin-RevId: 736665856
2025-03-13 15:50:46 -07:00
Hyeontaek Lim
178278863d [JAX] Fix api_benchmark broken by https://github.com/jax-ml/jax/pull/26569
`pjit_check_aval_sharding` expects `names: Sequence[str]`.

PiperOrigin-RevId: 734614264
2025-03-07 10:49:53 -08:00
Jake VanderPlas
5dc37d3f70 Remove internal uses of api_util.shaped_abstractify 2024-12-19 07:06:36 -08:00
Jake VanderPlas
2c722d9b13 Cleanup: toward merging core.concrete_aval & xla.abstractify 2024-12-17 09:27:00 -08:00
jax authors
eaefabee85 Fixes to api_benchmark.py. Testcases always fail without these fixes.
PiperOrigin-RevId: 668299061
2024-08-27 23:37:49 -07:00
Roy Frostig
371935cc10 update README and several docs to typed RNG keys 2024-08-11 08:09:47 -07:00
Mark Sandler
4a91396e91 Adds make_array_from_callback_sharded benchmark
PiperOrigin-RevId: 646138175
2024-06-24 10:17:40 -07:00
Peter Hawkins
9e30079dba [JAX] Add caching to pjit._infer_params.
When tracing inner jits, we currently redo a lot of tracing work, which we can cache. Just as we have a C++ fast path for top-level jit calls, we can reuse the same logic for inner jits. We use part of the C++ fast path code to compute the signature of the arguments and split apart the dynamic arguments to compute a cache key. If we have seen the cache key before, we can avoid doing most of the work of _infer_params.

In passing, fix a bug where DynamicJaxprTracer's shaped_abstractify rule sometimes produces concrete avals.

```
name           old cpu/op   new cpu/op   delta
jit_add_chain  59.1ms ±14%  49.4ms ±10%  -16.32%  (p=0.008 n=5+5)

name           old time/op          new time/op          delta
jit_add_chain  60.3ms ±14%          50.7ms ±11%  -15.99%          (p=0.008 n=5+5)
```

PiperOrigin-RevId: 645491650
2024-06-21 13:53:04 -07:00
Junwhan Ahn
6617a0d1ed Expand device_put benchmarks to run with different numbers of arrays and input types
For the upcoming batching changes for `device_put`, it is useful to benchmark `device_put` with varying numbers of arrays.

PiperOrigin-RevId: 641716268
2024-06-09 13:01:51 -07:00
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Jake VanderPlas
e59a0506fe Deprecate jax.tree_map in favor of jax.tree.map 2024-02-22 11:35:39 -08:00
Yash Katariya
0b542ff585 Add a benchmark measuring device_put's speed for a 4GB input array
```
---------------------------------------------------------
Benchmark               Time             CPU   Iterations
---------------------------------------------------------
device_put_big        419 ms        0.363 ms           10
```

PiperOrigin-RevId: 607512568
2024-02-15 17:47:10 -08:00
Matthew Johnson
e31018b9c5 in partial_eval_custom rule for pjit, cache ClosedJaxpr creation
Anywhere we call the ClosedJaxpr constructor, we had better be under a cache.
We should audit the code...

Never trust comments, especially when blame says mattjj wrote them

Co-authored-by: Yash Katariya <yashkatariya@google.com>
2023-12-21 15:35:45 -08:00
Yash Katariya
fd09b35645 Optimize make_array_from_callback for fully replicated shardings by going via batched_device_put
Before:

```
name                                                      cpu/op
bench_make_array_from_callback_fully_replicated_sharding  467µs ± 3%

name                                                      time/op
bench_make_array_from_callback_fully_replicated_sharding  467µs ± 3%
```

After:

```
name                                                      cpu/op
bench_make_array_from_callback_fully_replicated_sharding  28.1µs ± 2%

name                                                      time/op
bench_make_array_from_callback_fully_replicated_sharding  28.1µs ± 2%
```

PiperOrigin-RevId: 572429822
2023-10-10 19:02:04 -07:00
Patrick Kidger
9d73441ff1 Added serial_dot_products benchmark 2023-09-21 15:25:52 -07:00
jax authors
f498442daa [jax][benchmark] Added clearing caches for benchmarking compilation time in sparse JAX benchmarks
PiperOrigin-RevId: 553179605
2023-08-02 10:07:54 -07:00
Yash Katariya
a6254c75e0 Improve the shape incompatible error message by adding the argument/result name path to it.
PiperOrigin-RevId: 529605855
2023-05-04 21:50:04 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Jake VanderPlas
5521423d92 Change np.prod->math.prod
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
2023-04-13 11:48:11 -07:00
Peter Hawkins
74384e6a87 Add a C++ safe_zip implementation.
Benchmark results on my workstation:
```
name                                 old cpu/op   new cpu/op   delta
safe_zip/arg_lengths:0/num_args:1    1.22µs ± 1%  0.28µs ± 8%  -77.33%  (p=0.008 n=5+5)
safe_zip/arg_lengths:1/num_args:1    1.28µs ± 1%  0.34µs ± 6%  -73.18%  (p=0.008 n=5+5)
safe_zip/arg_lengths:2/num_args:1    1.28µs ± 1%  0.38µs ± 5%  -70.26%  (p=0.008 n=5+5)
safe_zip/arg_lengths:5/num_args:1    1.38µs ± 1%  0.51µs ± 3%  -63.26%  (p=0.008 n=5+5)
safe_zip/arg_lengths:10/num_args:1   1.61µs ± 1%  0.69µs ± 3%  -56.93%  (p=0.008 n=5+5)
safe_zip/arg_lengths:100/num_args:1  5.39µs ± 1%  3.83µs ± 2%  -29.03%  (p=0.008 n=5+5)
safe_zip/arg_lengths:0/num_args:2    1.46µs ± 1%  0.32µs ± 4%  -78.30%  (p=0.008 n=5+5)
safe_zip/arg_lengths:1/num_args:2    1.52µs ± 1%  0.39µs ± 4%  -74.20%  (p=0.008 n=5+5)
safe_zip/arg_lengths:2/num_args:2    1.53µs ± 1%  0.44µs ± 4%  -71.38%  (p=0.008 n=5+5)
safe_zip/arg_lengths:5/num_args:2    1.66µs ± 2%  0.60µs ± 3%  -63.96%  (p=0.008 n=5+5)
safe_zip/arg_lengths:10/num_args:2   1.90µs ± 1%  0.82µs ± 3%  -56.66%  (p=0.008 n=5+5)
safe_zip/arg_lengths:100/num_args:2  6.51µs ± 1%  4.80µs ± 0%  -26.23%  (p=0.016 n=5+4)
safe_zip/arg_lengths:0/num_args:3    1.62µs ± 1%  0.36µs ± 4%  -77.95%  (p=0.008 n=5+5)
safe_zip/arg_lengths:1/num_args:3    1.68µs ± 1%  0.44µs ± 3%  -73.75%  (p=0.008 n=5+5)
safe_zip/arg_lengths:2/num_args:3    1.69µs ± 1%  0.50µs ± 3%  -70.48%  (p=0.008 n=5+5)
safe_zip/arg_lengths:5/num_args:3    1.83µs ± 1%  0.68µs ± 2%  -62.73%  (p=0.008 n=5+5)
safe_zip/arg_lengths:10/num_args:3   2.12µs ± 1%  0.96µs ± 1%  -54.71%  (p=0.008 n=5+5)
safe_zip/arg_lengths:100/num_args:3  7.34µs ± 2%  5.89µs ± 1%  -19.74%  (p=0.008 n=5+5)
```

In addition, improve the length mismatch error for safe_map and define __module__ on both functions.

PiperOrigin-RevId: 523475834
2023-04-11 12:43:04 -07:00
Peter Hawkins
0dbd467cea Add a C++ implementation of safe map.
Before (argument names reversed, oops, fixed in code):

```
name                                 time/op
safe_map/num_args:0/arg_lengths:1    1.43µs ± 1%
safe_map/num_args:1/arg_lengths:1    1.61µs ± 1%
safe_map/num_args:2/arg_lengths:1    1.72µs ± 0%
safe_map/num_args:5/arg_lengths:1    2.14µs ± 1%
safe_map/num_args:10/arg_lengths:1   2.87µs ± 1%
safe_map/num_args:100/arg_lengths:1  15.6µs ± 1%
safe_map/num_args:0/arg_lengths:2    1.65µs ± 0%
safe_map/num_args:1/arg_lengths:2    1.83µs ± 1%
safe_map/num_args:2/arg_lengths:2    1.97µs ± 1%
safe_map/num_args:5/arg_lengths:2    2.41µs ± 1%
safe_map/num_args:10/arg_lengths:2   3.22µs ± 2%
safe_map/num_args:100/arg_lengths:2  17.0µs ± 2%
safe_map/num_args:0/arg_lengths:3    1.83µs ± 1%
safe_map/num_args:1/arg_lengths:3    2.02µs ± 1%
safe_map/num_args:2/arg_lengths:3    2.16µs ± 1%
safe_map/num_args:5/arg_lengths:3    2.63µs ± 1%
safe_map/num_args:10/arg_lengths:3   3.48µs ± 1%
safe_map/num_args:100/arg_lengths:3  18.1µs ± 1%
```

After:
```
name                                 time/op
safe_map/num_args:0/arg_lengths:1     409ns ± 1%
safe_map/num_args:1/arg_lengths:1     602ns ± 5%
safe_map/num_args:2/arg_lengths:1     777ns ± 4%
safe_map/num_args:5/arg_lengths:1    1.21µs ± 3%
safe_map/num_args:10/arg_lengths:1   1.93µs ± 2%
safe_map/num_args:100/arg_lengths:1  14.7µs ± 0%
safe_map/num_args:0/arg_lengths:2     451ns ± 1%
safe_map/num_args:1/arg_lengths:2     652ns ± 0%
safe_map/num_args:2/arg_lengths:2     850ns ± 4%
safe_map/num_args:5/arg_lengths:2    1.32µs ± 3%
safe_map/num_args:10/arg_lengths:2   2.11µs ± 2%
safe_map/num_args:100/arg_lengths:2  16.0µs ± 1%
safe_map/num_args:0/arg_lengths:3     496ns ± 1%
safe_map/num_args:1/arg_lengths:3     718ns ± 5%
safe_map/num_args:2/arg_lengths:3     919ns ± 4%
safe_map/num_args:5/arg_lengths:3    1.43µs ± 2%
safe_map/num_args:10/arg_lengths:3   2.30µs ± 2%
safe_map/num_args:100/arg_lengths:3  17.3µs ± 1%
```
PiperOrigin-RevId: 523263207
2023-04-10 18:09:56 -07:00
Yash Katariya
694e43a44a Remove experimental_cpp_jit since that flag is unused and also remove experimental_cpp_pjit.
For dynamic shapes experimentation and normal debugging, `python_pjit` still exists so that problem doesn't exist which makes us free to remove these 2 flags.

I am leaving pmap's flag alone for now.

PiperOrigin-RevId: 522602754
2023-04-07 08:29:20 -07:00
Peter Hawkins
452f3c55e3 Rename jax._src.sharding_utils to jax._src.op_shardings.
Move some more op_sharding related helpers to that module.

PiperOrigin-RevId: 522343010
2023-04-06 08:32:46 -07:00
Yash Katariya
1faa7a8edd Add benchmarks for accessing index and replica id in addressable_shards
PiperOrigin-RevId: 517974091
2023-03-20 08:22:34 -07:00
Parker Schuh
48702171bf Add benchmarks for np.array, device_put, and _arrays.
PiperOrigin-RevId: 516692492
2023-03-14 19:06:06 -07:00
Yash Katariya
233911c001 [Fix forward] Rollback the device_put_sharded and device_put_replicated change of using batched_device_put
PiperOrigin-RevId: 516244071
2023-03-13 10:07:44 -07:00
Peter Hawkins
1925aa1109 Split Sharding subclasses out of _src/sharding.py into _src/sharding_impls.py
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.

PiperOrigin-RevId: 516223009
2023-03-13 08:50:18 -07:00
Lena Martens
4f48f94649 Update api_benchmark to not use any deprecated APIs.
PiperOrigin-RevId: 512941633
2023-02-28 08:33:26 -08:00
Yash Katariya
418c2f9d2a Rename in_axis_resources and out_axis_resources with in_shardings and out_shardings. This is just a simple name replacement. It does not change any of the current pjit semantics and doesn't break any code.
This is a safe and trivial name replacement. It does not change any of the semantics. You can still pass in PatitionSpecs to in_shardings and out_shardings.

PiperOrigin-RevId: 510671300
2023-02-18 10:00:36 -08:00
Peter Hawkins
428189f8fb Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.

PiperOrigin-RevId: 506994892
2023-02-03 14:28:45 -08:00
Jake VanderPlas
43e57db77a Begin deprecation of public jax.ShapedArray 2023-01-30 11:27:58 -08:00
jax authors
eb875cd5dd Added a pattern-match optimisation for inplace-select.
PiperOrigin-RevId: 497425937
2022-12-23 16:05:56 -08:00
Peter Hawkins
d6c67c97db Remove redundant dtype canonicalization from jax.device_put().
Gives a small improvement to the included jax.device_put() benchmark on my VM:

```
name        old cpu/op  new cpu/op  delta
device_put  91.3µs ± 5%  80.1µs ± 3%  -12.29%  (p=0.008 n=5+5)

name        old time/op             new time/op             delta
device_put  91.4µs ± 5%             80.1µs ± 3%  -12.29%          (p=0.008 n=5+5)
```

jax.device_put() has not been optimized that much yet and there is plenty of room for further improvement.

PiperOrigin-RevId: 491727173
2022-11-29 13:47:36 -08:00
Yash Katariya
928dee415f Optimize host_local_array_to_global_array by caching the local to global conversion and flattening of axis resources. Also take a fast path for device_put which does not do abstractify and only canonicalize_dtype on the entire array once (instead of doing it for every shard).
This results in a 5x speedup!

Before:

```
---------------------------------------------------------------------------
Benchmark                                 Time             CPU   Iterations
---------------------------------------------------------------------------
host_local_array_to_global_array       3.03 ms         3.02 ms          220
```

After:

```
---------------------------------------------------------------------------
Benchmark                                 Time             CPU   Iterations
---------------------------------------------------------------------------
host_local_array_to_global_array      0.673 ms        0.671 ms          985
```

PiperOrigin-RevId: 489880547
2022-11-20 20:53:02 -08:00
Yash Katariya
c42bad85ef Make MeshPspecSharding an alias for NamedSharding (it was the other way around before this CL).
PiperOrigin-RevId: 488473538
2022-11-14 14:44:00 -08:00
jax authors
f4be5ab173 Merge pull request #12219 from jakevdp:indexing-slice
PiperOrigin-RevId: 485946084
2022-11-03 12:44:28 -07:00
Yash Katariya
532cd7ed74 Skip the benchmarks properly via state.skip_with_error when enough devices are not present.
PiperOrigin-RevId: 485931295
2022-11-03 11:44:57 -07:00
Jake VanderPlas
753562d574 Add benchmarks for repeated static indexing & slicing 2022-11-03 11:41:37 -07:00
Hyeontaek Lim
fc8f40ce0e Internal visibility change
PiperOrigin-RevId: 484340424
2022-10-27 13:49:16 -07:00
Yash Katariya
cf6b5097d0 Remove pytest_benchmark for test-requirements.txt and move the benchmark file which was using that package to use google_benchmark.
PiperOrigin-RevId: 483736267
2022-10-25 11:59:32 -07:00
Yash Katariya
3572bb2db0 [Rollback]
Allow uncommitted single device PyArray in C++ pjit path.

PiperOrigin-RevId: 482084898
2022-10-18 19:42:10 -07:00
Kuangyuan Chen
d64da3d407 Roll forward with fix: Remove the original python function fun_ from C++ PjitFunction, as the destroying fun_ may yield the thread in some cases, which causes error during deleting the python object of PjitFunction.
PiperOrigin-RevId: 481950912
2022-10-18 10:05:53 -07:00
Kuangyuan Chen
fd2f590b3b Allow uncommitted single device PyArray in C++ pjit path.
PiperOrigin-RevId: 481711690
2022-10-17 12:35:30 -07:00
jax authors
504b3c1b25 roll forward with the fix: Make params arg in Compiled.call() position-only so that it does not conflict with the keyword args.
PiperOrigin-RevId: 481666211
2022-10-17 09:50:55 -07:00
Kuangyuan Chen
38a7582923 roll forward with the fix: Make params arg in Compiled.call() position-only so that it does not conflict with the keyword args.
PiperOrigin-RevId: 481181330
2022-10-14 10:42:15 -07:00
jax authors
1945208d34 Rollback because of failing tests internally.
PiperOrigin-RevId: 481103002
2022-10-14 03:12:42 -07:00
Kuangyuan Chen
d082ea0d46 Implement a fast path for pjit AOT in C++ for jax.Array inputs.
PiperOrigin-RevId: 480983807
2022-10-13 14:24:05 -07:00
Yash Katariya
9e4114f0f1 Move array.py and sharding.py from experimental/ to _src/.
PiperOrigin-RevId: 477201711
2022-09-27 10:06:52 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Kuangyuan Chen
405a2310ce Implement pjit fast path in cpp for jax.Array inputs
PiperOrigin-RevId: 475988677
2022-09-21 20:18:18 -07:00