90 Commits

Author SHA1 Message Date
Parker Schuh
48702171bf Add benchmarks for np.array, device_put, and _arrays.
PiperOrigin-RevId: 516692492
2023-03-14 19:06:06 -07:00
Yash Katariya
233911c001 [Fix forward] Rollback the device_put_sharded and device_put_replicated change of using batched_device_put
PiperOrigin-RevId: 516244071
2023-03-13 10:07:44 -07:00
Peter Hawkins
1925aa1109 Split Sharding subclasses out of _src/sharding.py into _src/sharding_impls.py
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.

PiperOrigin-RevId: 516223009
2023-03-13 08:50:18 -07:00
Emilio Cota
6f1d82916c math_benchmark: add --set_env flag
PiperOrigin-RevId: 515417422
2023-03-09 13:04:12 -08:00
Emilio Cota
845d68b39e math_benchmark: add dot op
PiperOrigin-RevId: 515408666
2023-03-09 12:24:47 -08:00
Peter Hawkins
8fb1fd318d Replace jax._src.util.prod with math.prod.
math.prod() was added in Python 3.8, so we can assume it is always present.

PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Lena Martens
4f48f94649 Update api_benchmark to not use any deprecated APIs.
PiperOrigin-RevId: 512941633
2023-02-28 08:33:26 -08:00
Yash Katariya
418c2f9d2a Rename in_axis_resources and out_axis_resources with in_shardings and out_shardings. This is just a simple name replacement. It does not change any of the current pjit semantics and doesn't break any code.
This is a safe and trivial name replacement. It does not change any of the semantics. You can still pass in PatitionSpecs to in_shardings and out_shardings.

PiperOrigin-RevId: 510671300
2023-02-18 10:00:36 -08:00
Yash Katariya
d21ff0371f Remove gda_benchmark file as GDA is deprecated.
PiperOrigin-RevId: 510469600
2023-02-17 10:46:25 -08:00
Peter Hawkins
88cc254f2c [JAX] Replace uses of jax.interpreters.pxla.ShardedDeviceArray with jax.Array.
PiperOrigin-RevId: 508463147
2023-02-09 13:39:41 -08:00
Peter Hawkins
98b75cf27b Prune accidental exports from jax.interpreters.pxla.
These imports do not appear to have users outside JAX itself.

PiperOrigin-RevId: 507835295
2023-02-07 11:16:42 -08:00
Peter Hawkins
428189f8fb Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.

PiperOrigin-RevId: 506994892
2023-02-03 14:28:45 -08:00
Jake VanderPlas
43e57db77a Begin deprecation of public jax.ShapedArray 2023-01-30 11:27:58 -08:00
Emilio Cota
13e875f8b8 benchmarks: add math unary benchmarks
These will be used for benchmarking FP approximations in XLA.

PiperOrigin-RevId: 503991586
2023-01-23 08:17:16 -08:00
jax authors
eb875cd5dd Added a pattern-match optimisation for inplace-select.
PiperOrigin-RevId: 497425937
2022-12-23 16:05:56 -08:00
Peter Hawkins
d6c67c97db Remove redundant dtype canonicalization from jax.device_put().
Gives a small improvement to the included jax.device_put() benchmark on my VM:

```
name        old cpu/op  new cpu/op  delta
device_put  91.3µs ± 5%  80.1µs ± 3%  -12.29%  (p=0.008 n=5+5)

name        old time/op             new time/op             delta
device_put  91.4µs ± 5%             80.1µs ± 3%  -12.29%          (p=0.008 n=5+5)
```

jax.device_put() has not been optimized that much yet and there is plenty of room for further improvement.

PiperOrigin-RevId: 491727173
2022-11-29 13:47:36 -08:00
Yash Katariya
928dee415f Optimize host_local_array_to_global_array by caching the local to global conversion and flattening of axis resources. Also take a fast path for device_put which does not do abstractify and only canonicalize_dtype on the entire array once (instead of doing it for every shard).
This results in a 5x speedup!

Before:

```
---------------------------------------------------------------------------
Benchmark                                 Time             CPU   Iterations
---------------------------------------------------------------------------
host_local_array_to_global_array       3.03 ms         3.02 ms          220
```

After:

```
---------------------------------------------------------------------------
Benchmark                                 Time             CPU   Iterations
---------------------------------------------------------------------------
host_local_array_to_global_array      0.673 ms        0.671 ms          985
```

PiperOrigin-RevId: 489880547
2022-11-20 20:53:02 -08:00
Yash Katariya
c42bad85ef Make MeshPspecSharding an alias for NamedSharding (it was the other way around before this CL).
PiperOrigin-RevId: 488473538
2022-11-14 14:44:00 -08:00
jax authors
f4be5ab173 Merge pull request #12219 from jakevdp:indexing-slice
PiperOrigin-RevId: 485946084
2022-11-03 12:44:28 -07:00
Yash Katariya
532cd7ed74 Skip the benchmarks properly via state.skip_with_error when enough devices are not present.
PiperOrigin-RevId: 485931295
2022-11-03 11:44:57 -07:00
Jake VanderPlas
753562d574 Add benchmarks for repeated static indexing & slicing 2022-11-03 11:41:37 -07:00
Hyeontaek Lim
fc8f40ce0e Internal visibility change
PiperOrigin-RevId: 484340424
2022-10-27 13:49:16 -07:00
Yash Katariya
cf6b5097d0 Remove pytest_benchmark for test-requirements.txt and move the benchmark file which was using that package to use google_benchmark.
PiperOrigin-RevId: 483736267
2022-10-25 11:59:32 -07:00
Yash Katariya
3572bb2db0 [Rollback]
Allow uncommitted single device PyArray in C++ pjit path.

PiperOrigin-RevId: 482084898
2022-10-18 19:42:10 -07:00
Kuangyuan Chen
d64da3d407 Roll forward with fix: Remove the original python function fun_ from C++ PjitFunction, as the destroying fun_ may yield the thread in some cases, which causes error during deleting the python object of PjitFunction.
PiperOrigin-RevId: 481950912
2022-10-18 10:05:53 -07:00
Kuangyuan Chen
fd2f590b3b Allow uncommitted single device PyArray in C++ pjit path.
PiperOrigin-RevId: 481711690
2022-10-17 12:35:30 -07:00
jax authors
504b3c1b25 roll forward with the fix: Make params arg in Compiled.call() position-only so that it does not conflict with the keyword args.
PiperOrigin-RevId: 481666211
2022-10-17 09:50:55 -07:00
Kuangyuan Chen
38a7582923 roll forward with the fix: Make params arg in Compiled.call() position-only so that it does not conflict with the keyword args.
PiperOrigin-RevId: 481181330
2022-10-14 10:42:15 -07:00
jax authors
1945208d34 Rollback because of failing tests internally.
PiperOrigin-RevId: 481103002
2022-10-14 03:12:42 -07:00
Kuangyuan Chen
d082ea0d46 Implement a fast path for pjit AOT in C++ for jax.Array inputs.
PiperOrigin-RevId: 480983807
2022-10-13 14:24:05 -07:00
Yash Katariya
84768d2d49 Replace jax.xla.DeviceArray private type with the new public type jax.Array.
PiperOrigin-RevId: 477582562
2022-09-28 16:34:10 -07:00
Yash Katariya
9e4114f0f1 Move array.py and sharding.py from experimental/ to _src/.
PiperOrigin-RevId: 477201711
2022-09-27 10:06:52 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Kuangyuan Chen
405a2310ce Implement pjit fast path in cpp for jax.Array inputs
PiperOrigin-RevId: 475988677
2022-09-21 20:18:18 -07:00
Kuangyuan Chen
b9e384363e Add many args benchmark for jax.Array
PiperOrigin-RevId: 475853211
2022-09-21 09:54:51 -07:00
Kuangyuan Chen
b764aadbcf Add pmap benchmarks using jax.Array
PiperOrigin-RevId: 473325120
2022-09-09 13:10:42 -07:00
Kuangyuan Chen
d17e516ea7 Add benchmarks for jax.Array
PiperOrigin-RevId: 471889808
2022-09-02 14:45:09 -07:00
Luca Di Grazia
88f6051233
target_total_secs has type int but used as type None
"filename": "benchmarks/benchmark.py"
"warning_type": "Incompatible variable type [9]",
"warning_message": " target_total_secs is declared to have type `int` but is used as type `None`."
"warning_line": 86
"fix": int to Optional[int]
2022-08-22 15:46:38 +02:00
Matthew Johnson
b3b4ffbc21 make slicing compilation a little faster
For the common special case of int or slice-of-int/None indexing, we can
generate a lax.slice rather than a lax.gather. That makes compilation a little
faster, and makes the generated jaxpr a bit more wieldy too, to process in
transformations and to read when pretty-printed.
2022-08-19 09:14:37 -07:00
Matthew Johnson
42dd7cac43 simplify slicing jaxprs a little
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-08-12 19:52:02 -07:00
Matthew Johnson
e3a92d52ba prepare to switch to new remat
This commit involves a few things, which are all united in being about landing
the new remat (aka new checkpoint) implementation:
  * add benchmarks for new remat eager performance, and some caching to make those
    benchmarks fast
  * warn when the old-remat-exclusive `concrete` feature is used, with an
    actionable message pointing to the new recommended approach involving static_argnums
  * add the static_argnums parameter to both new and old remt
  * update docstrings (and de-duplicate them to)
  * add new tests, especially around caching and errors/warnings
2022-08-04 12:25:03 -07:00
Yash Katariya
2109c6ec8c Make is_compatible_aval an optional method which sharding interfaces can implement to raise a more meaningful error. Otherwise lower to opsharding and catch the error if it fails.
PiperOrigin-RevId: 464147877
2022-07-29 13:37:15 -07:00
Yash Katariya
47623264db Export HloSharding via pybind which is a C++ wrapper around OpSharding proto.
PiperOrigin-RevId: 463992136
2022-07-28 21:01:15 -07:00
Yash Katariya
f4637c364d Fix the gda_xla_sharding_match benchmark which was regressing. This was happening because that function was executed from top to bottom a couple of times and each time a new mesh object was created violating the already created cache which doesn't happen in real life.
```
gda_xla_sharding_match_(256, 8)_PartitionSpec('x', 'y')     21.8ms ± 2%              1.3ms ± 2%  -93.80%          (p=0.008 n=5+5)
gda_xla_sharding_match_(256, 8)_PartitionSpec(None,)        21.8ms ± 4%              1.3ms ± 1%  -93.92%          (p=0.008 n=5+5)
gda_xla_sharding_match_(256, 8)_PartitionSpec('x',)         21.8ms ± 3%              1.3ms ± 1%  -94.11%          (p=0.008 n=5+5)
gda_xla_sharding_match_(256, 8)_PartitionSpec('y',)         21.8ms ± 3%              1.3ms ± 0%  -94.12%          (p=0.008 n=5+5)
gda_xla_sharding_match_(256, 8)_PartitionSpec(('x', 'y'),)  21.8ms ± 3%              1.3ms ± 1%  -94.07%          (p=0.008 n=5+5)
gda_xla_sharding_match_(128, 8)_PartitionSpec('x', 'y')     13.9ms ± 6%              1.3ms ± 1%  -90.85%          (p=0.008 n=5+5)
gda_xla_sharding_match_(4, 2)_PartitionSpec('x', 'y')       5.72ms ±10%             1.25ms ± 1%  -78.15%          (p=0.008 n=5+5)
gda_xla_sharding_match_(16, 4)_PartitionSpec('x', 'y')      6.17ms ±11%             1.25ms ± 1%  -79.71%          (p=0.008 n=5+5)
gda_xla_sharding_match_(16, 4)_PartitionSpec(('x', 'y'),)   6.17ms ±10%             1.26ms ± 2%  -79.61%          (p=0.008 n=5+5)
```

PiperOrigin-RevId: 463760534
2022-07-27 23:08:55 -07:00
Matthew Johnson
148173630f add an optional fastpath for api_util.shaped_abstractify
also add a benchmark for it, 8.7ms -> 0.2ms on my machine

Co-authored-by: Yash Katariya <yashkatariya@google.com>
2022-07-27 15:14:37 -07:00
Yash Katariya
d8cbb29d14 OpSharding doesn't have __eq__ defined on it. Don't check sharding equality using opsharding until it does support that.
PiperOrigin-RevId: 462238497
2022-07-20 15:03:39 -07:00
Jake VanderPlas
a10f0377db Avoid top-level aliases of jax.tree_util.* 2022-07-07 11:41:02 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Yash Katariya
6a7a34603d Move PartitionSpec from sharded_jit.py to pxla.py. The public endpoint is via jax.experimental so that should be used (no changes to the public endpoint).
This move is because sharded_jit is being deprecated.

PiperOrigin-RevId: 439948391
2022-04-06 15:19:19 -07:00
Yash Katariya
e08bc27bf0 Speed up GDA initialization by making local_shards lazy and hiding checks behind config.jax_enable_checks flag.
PiperOrigin-RevId: 438115859
2022-03-29 13:51:11 -07:00