64 Commits

Author SHA1 Message Date
jax authors
504b3c1b25 roll forward with the fix: Make params arg in Compiled.call() position-only so that it does not conflict with the keyword args.
PiperOrigin-RevId: 481666211
2022-10-17 09:50:55 -07:00
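The fix described in this commit message relies on Python's positional-only parameter marker (PEP 570). A minimal sketch with hypothetical function names (not the actual `Compiled.call` internals):

```python
# Sketch of the fix (hypothetical names, not the real Compiled.call):
# if `params` is an ordinary parameter, a user keyword argument that
# happens to be named `params` collides with it.
def call_bad(params, **kwargs):
    return params, kwargs

# With the `/` marker (PEP 570), `params` is positional-only, so a
# keyword argument named `params` simply lands in kwargs.
def call_good(params, /, **kwargs):
    return params, kwargs

assert call_good({"w": 1}, params=2) == ({"w": 1}, {"params": 2})
try:
    call_bad({"w": 1}, params=2)   # TypeError: multiple values for 'params'
except TypeError:
    pass
```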
Kuangyuan Chen
38a7582923 roll forward with the fix: Make params arg in Compiled.call() position-only so that it does not conflict with the keyword args.
PiperOrigin-RevId: 481181330
2022-10-14 10:42:15 -07:00
jax authors
1945208d34 Rollback because of failing tests internally.
PiperOrigin-RevId: 481103002
2022-10-14 03:12:42 -07:00
Kuangyuan Chen
d082ea0d46 Implement a fast path for pjit AOT in C++ for jax.Array inputs.
PiperOrigin-RevId: 480983807
2022-10-13 14:24:05 -07:00
Yash Katariya
84768d2d49 Replace jax.xla.DeviceArray private type with the new public type jax.Array.
PiperOrigin-RevId: 477582562
2022-09-28 16:34:10 -07:00
Yash Katariya
9e4114f0f1 Move array.py and sharding.py from experimental/ to _src/.
PiperOrigin-RevId: 477201711
2022-09-27 10:06:52 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Kuangyuan Chen
405a2310ce Implement pjit fast path in cpp for jax.Array inputs
PiperOrigin-RevId: 475988677
2022-09-21 20:18:18 -07:00
Kuangyuan Chen
b9e384363e Add many args benchmark for jax.Array
PiperOrigin-RevId: 475853211
2022-09-21 09:54:51 -07:00
Kuangyuan Chen
b764aadbcf Add pmap benchmarks using jax.Array
PiperOrigin-RevId: 473325120
2022-09-09 13:10:42 -07:00
Kuangyuan Chen
d17e516ea7 Add benchmarks for jax.Array
PiperOrigin-RevId: 471889808
2022-09-02 14:45:09 -07:00
Luca Di Grazia
88f6051233
Fix type annotation: `target_total_secs` is declared as `int` but used as `None`.
filename: benchmarks/benchmark.py
warning_type: Incompatible variable type [9]
warning_message: `target_total_secs` is declared to have type `int` but is used as type `None`.
warning_line: 86
fix: change `int` to `Optional[int]`
2022-08-22 15:46:38 +02:00
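The fix recorded above amounts to widening the annotation so it admits `None`. A minimal sketch (the signature and fallback value are hypothetical, mirroring the warning rather than the actual benchmarks/benchmark.py code):

```python
from typing import Optional

# Before (the warning): declared `int`, but the default is `None`.
#   def run_benchmark(target_total_secs: int = None): ...

# After: Optional[int] makes the None default explicit.
def run_benchmark(target_total_secs: Optional[int] = None) -> int:
    if target_total_secs is None:
        target_total_secs = 10   # assumed fallback, for illustration only
    return target_total_secs

assert run_benchmark() == 10
assert run_benchmark(5) == 5
```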
Matthew Johnson
b3b4ffbc21 make slicing compilation a little faster
For the common special case of int or slice-of-int/None indexing, we can
generate a lax.slice rather than a lax.gather. That makes compilation a little
faster, and makes the generated jaxpr a bit more wieldy too, to process in
transformations and to read when pretty-printed.
2022-08-19 09:14:37 -07:00
Matthew Johnson
42dd7cac43 simplify slicing jaxprs a little
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-08-12 19:52:02 -07:00
Matthew Johnson
e3a92d52ba prepare to switch to new remat
This commit involves a few things, which are all united in being about landing
the new remat (aka new checkpoint) implementation:
  * add benchmarks for new remat eager performance, and some caching to make those
    benchmarks fast
  * warn when the old-remat-exclusive `concrete` feature is used, with an
    actionable message pointing to the new recommended approach involving static_argnums
  * add the static_argnums parameter to both new and old remat
  * update docstrings (and de-duplicate them too)
  * add new tests, especially around caching and errors/warnings
2022-08-04 12:25:03 -07:00
Yash Katariya
2109c6ec8c Make is_compatible_aval an optional method which sharding interfaces can implement to raise a more meaningful error. Otherwise, lower to OpSharding and catch the error if it fails.
PiperOrigin-RevId: 464147877
2022-07-29 13:37:15 -07:00
Yash Katariya
47623264db Export HloSharding via pybind which is a C++ wrapper around OpSharding proto.
PiperOrigin-RevId: 463992136
2022-07-28 21:01:15 -07:00
Yash Katariya
f4637c364d Fix the gda_xla_sharding_match benchmark, which was regressing. This happened because the benchmark function was executed from top to bottom several times, and each run created a new mesh object, invalidating the already-populated cache; that doesn't happen in real usage.
```
gda_xla_sharding_match_(256, 8)_PartitionSpec('x', 'y')     21.8ms ± 2%              1.3ms ± 2%  -93.80%          (p=0.008 n=5+5)
gda_xla_sharding_match_(256, 8)_PartitionSpec(None,)        21.8ms ± 4%              1.3ms ± 1%  -93.92%          (p=0.008 n=5+5)
gda_xla_sharding_match_(256, 8)_PartitionSpec('x',)         21.8ms ± 3%              1.3ms ± 1%  -94.11%          (p=0.008 n=5+5)
gda_xla_sharding_match_(256, 8)_PartitionSpec('y',)         21.8ms ± 3%              1.3ms ± 0%  -94.12%          (p=0.008 n=5+5)
gda_xla_sharding_match_(256, 8)_PartitionSpec(('x', 'y'),)  21.8ms ± 3%              1.3ms ± 1%  -94.07%          (p=0.008 n=5+5)
gda_xla_sharding_match_(128, 8)_PartitionSpec('x', 'y')     13.9ms ± 6%              1.3ms ± 1%  -90.85%          (p=0.008 n=5+5)
gda_xla_sharding_match_(4, 2)_PartitionSpec('x', 'y')       5.72ms ±10%             1.25ms ± 1%  -78.15%          (p=0.008 n=5+5)
gda_xla_sharding_match_(16, 4)_PartitionSpec('x', 'y')      6.17ms ±11%             1.25ms ± 1%  -79.71%          (p=0.008 n=5+5)
gda_xla_sharding_match_(16, 4)_PartitionSpec(('x', 'y'),)   6.17ms ±10%             1.26ms ± 2%  -79.61%          (p=0.008 n=5+5)
```

PiperOrigin-RevId: 463760534
2022-07-27 23:08:55 -07:00
Matthew Johnson
148173630f add an optional fastpath for api_util.shaped_abstractify
also add a benchmark for it, 8.7ms -> 0.2ms on my machine

Co-authored-by: Yash Katariya <yashkatariya@google.com>
2022-07-27 15:14:37 -07:00
Yash Katariya
d8cbb29d14 OpSharding doesn't have __eq__ defined on it. Don't check sharding equality via OpSharding until it supports that.
PiperOrigin-RevId: 462238497
2022-07-20 15:03:39 -07:00
Jake VanderPlas
a10f0377db Avoid top-level aliases of jax.tree_util.* 2022-07-07 11:41:02 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

2022-05-17 22:14:05 +01:00
Yash Katariya
6a7a34603d Move PartitionSpec from sharded_jit.py to pxla.py. The public endpoint is via jax.experimental so that should be used (no changes to the public endpoint).
This move is because sharded_jit is being deprecated.

PiperOrigin-RevId: 439948391
2022-04-06 15:19:19 -07:00
Yash Katariya
e08bc27bf0 Speed up GDA initialization by making local_shards lazy and hiding checks behind config.jax_enable_checks flag.
PiperOrigin-RevId: 438115859
2022-03-29 13:51:11 -07:00
Yash Katariya
a68b0f3a0a Fix the mesh_axes of benchmarks to be Pspecs
PiperOrigin-RevId: 437932954
2022-03-28 21:51:01 -07:00
Qiao Zhang
5d7f639769 Add small and big matmul to api_benchmarks.
name                                  cpu/op
jit_small_matmul                      2.96µs ± 2%
jit_big_matmul                        22.1µs ±21%

name                                  time/op

jit_small_matmul                      2.96µs ± 2%
jit_big_matmul                        22.7µs ±21%

PiperOrigin-RevId: 435453853
2022-03-17 14:51:17 -07:00
Yash Katariya
4e47de66fc Add the cache back now that Mesh's __hash__ is also being hashed on self.devices.shape.
PiperOrigin-RevId: 425711067
2022-02-01 14:06:01 -08:00
Yash Katariya
f80887e69c A couple of changes prompted by the serialization inconsistencies being observed.
* Remove the cache, since one of its keys is global_mesh. The hash of global_mesh doesn't account for the mesh topology, only the devices. That is a problem because indices are assigned to devices based on the topology: if the mesh is `(4, 2)` and you then provide a new mesh `(2, 4)`, the cache returns the results for `(4, 2)` because the devices haven't changed, even though the indices assigned under the `(2, 4)` topology differ from those under `(4, 2)`.

```
mesh_devices = np.array(jax.devices()).reshape((4, 2))
mesh_axes = ('x' , 'y')
global_mesh1 = Mesh(mesh_devices, mesh_axes)
print(global_mesh1)

mesh_devices = np.array(jax.devices()).reshape((2, 4))
mesh_axes = ('x' , 'y')
global_mesh2 = Mesh(mesh_devices, mesh_axes)
print(global_mesh2)

hash(global_mesh1) == hash(global_mesh2)

Output:

Mesh(array([[0, 1],
       [2, 3],
       [4, 5],
       [6, 7]]), ('x', 'y'))
Mesh(array([[0, 1, 2, 3],
       [4, 5, 6, 7]]), ('x', 'y'))
True
```

* Remove the replica_id calculation math, since after removing it the serialization inconsistencies were resolved. This is still confusing to me, as I can't reproduce it in unit tests or in small model runs, but I'll debug it in parallel; the important thing here is to unblock everyone. Replacing it with _hashed_index is still 2x faster than using the _HashableIndex class.

PiperOrigin-RevId: 425525653
2022-01-31 20:48:17 -08:00
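The collision demonstrated in this commit message can be reduced to a toy example. The sketch below uses a hypothetical `ToyMesh` class (not the real jax `Mesh`) to show both the bug, where hashing only the device ids makes a `(4, 2)` mesh and a `(2, 4)` mesh collide, and the repair the "Add the cache back" commit above describes, where the hash also covers `devices.shape`:

```python
# Toy stand-in for Mesh: hashes only the flat device ids, so two meshes
# with the same devices but different topologies collide.
class ToyMesh:
    def __init__(self, device_ids, shape):
        self.device_ids = tuple(device_ids)
        self.shape = shape

    def __hash__(self):
        return hash(self.device_ids)  # buggy: ignores topology

# Fixed variant: include the shape in the hash, so topology matters.
class FixedToyMesh(ToyMesh):
    def __hash__(self):
        return hash((self.device_ids, self.shape))

ids = range(8)
assert hash(ToyMesh(ids, (4, 2))) == hash(ToyMesh(ids, (2, 4)))      # collision
assert hash(FixedToyMesh(ids, (4, 2))) != hash(FixedToyMesh(ids, (2, 4)))
```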
Yash Katariya
7f192c1946 Cache the expensive computations in GDA. For example, get_shard_indices_replica_ids can be the same for multiple variables in a neural network, since their (global_shape, mesh_axes, global_mesh) can be the same.
Note that the first call will be a little slow. The timings below show the caching working, because the benchmark runs for multiple iterations and the time is averaged over them.

```
name                                                     time/op
gda_construction_callback_(4, 2)_['x', 'y']              4.50ms ±10%
gda_construction_raw_(256, 8)_['x', 'y']                 5.82ms ± 2%
indices_replica_id_calc__uncached_(256, 8)_['x', 'y']    2.95ms ± 6%
indices_replica_id_calc_cached_(256, 8)_['x', 'y']       28.7µs ± 1%
gda_construction_callback_(4, 2)_[None]                  31.9ms ±20%
gda_construction_raw_(256, 8)_[None]                     5.85ms ± 5%
indices_replica_id_calc__uncached_(256, 8)_[None]        1.75ms ± 1%
indices_replica_id_calc_cached_(256, 8)_[None]           29.0µs ± 4%
gda_construction_callback_(4, 2)_['x']                   8.40ms ± 4%
gda_construction_raw_(256, 8)_['x']                      5.48ms ± 2%
indices_replica_id_calc__uncached_(256, 8)_['x']         1.89ms ± 1%
indices_replica_id_calc_cached_(256, 8)_['x']            29.0µs ± 4%
gda_construction_callback_(4, 2)_['y']                   15.3ms ± 6%
gda_construction_raw_(256, 8)_['y']                      5.66ms ± 5%
indices_replica_id_calc__uncached_(256, 8)_['y']         1.82ms ± 2%
indices_replica_id_calc_cached_(256, 8)_['y']            29.4µs ± 3%
gda_construction_callback_(4, 2)_[('x', 'y')]            4.29ms ± 5%
gda_construction_raw_(256, 8)_[('x', 'y')]               5.61ms ± 7%
indices_replica_id_calc__uncached_(256, 8)_[('x', 'y')]  3.81ms ±10%
indices_replica_id_calc_cached_(256, 8)_[('x', 'y')]     29.0µs ± 5%
gda_construction_raw_(128, 8)_['x', 'y']                 2.42ms ± 1%
indices_replica_id_calc__uncached_(128, 8)_['x', 'y']    1.14ms ±11%
indices_replica_id_calc_cached_(128, 8)_['x', 'y']       19.9µs ± 1%
gda_construction_raw_(4, 2)_['x', 'y']                   46.7µs ± 0%
indices_replica_id_calc__uncached_(4, 2)_['x', 'y']       153µs ± 4%
indices_replica_id_calc_cached_(4, 2)_['x', 'y']         11.1µs ± 8%
gda_construction_raw_(16, 4)_['x', 'y']                   164µs ± 2%
indices_replica_id_calc__uncached_(16, 4)_['x', 'y']      212µs ± 3%
indices_replica_id_calc_cached_(16, 4)_['x', 'y']        11.3µs ± 1%
gda_construction_raw_(16, 4)_[('x', 'y')]                 163µs ± 2%
indices_replica_id_calc__uncached_(16, 4)_[('x', 'y')]    210µs ± 2%
indices_replica_id_calc_cached_(16, 4)_[('x', 'y')]      11.6µs ± 8%
```

PiperOrigin-RevId: 422639127
2022-01-18 13:53:52 -08:00
Yash Katariya
0532a63261 Optimizations for GDA to make creating GDA faster.
* Use math to figure out the replica id. Using `_hashed_index` (note that this is a function, not the `_HashableIndex` class, which no longer exists) is 1.5 - 2 times slower than using math. markdaoust@ helped with the math here (going to the office has its own perks :) )

* Get rid of the `_HashableIndex` class and replace it with a function, `_hashed_index`. The dataclass was extremely slow.

* Only calculate global_mesh.local_devices once, even though it's a cached property (cached_property is only available from Python 3.8).

```
name                                           old time/op             new time/op             delta
gda_construction_callback_(4, 2)_['x', 'y']    4.77ms ± 5%             4.74ms ± 5%     ~           (p=0.316 n=14+17)
gda_construction_raw_(256, 8)_['x', 'y']       17.9ms ± 5%              9.0ms ± 2%  -49.92%          (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_['x', 'y']    11.4ms ± 2%              2.9ms ± 2%  -74.52%          (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_[None]        34.0ms ±20%             30.5ms ± 2%     ~             (p=0.413 n=5+4)
gda_construction_raw_(256, 8)_[None]           15.9ms ± 2%              7.7ms ± 3%  -51.56%          (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_[None]        9.39ms ± 3%             1.74ms ± 2%  -81.44%          (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_['x']         8.87ms ± 2%             8.92ms ± 3%     ~             (p=0.841 n=5+5)
gda_construction_raw_(256, 8)_['x']            16.4ms ± 2%              7.7ms ± 1%  -52.66%          (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_['x']         9.85ms ± 1%             1.90ms ± 2%  -80.68%          (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_['y']         15.9ms ± 3%             16.0ms ± 5%     ~             (p=0.690 n=5+5)
gda_construction_raw_(256, 8)_['y']            15.8ms ± 3%              7.6ms ± 1%  -52.04%          (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_['y']         9.29ms ± 1%             1.78ms ± 1%  -80.79%          (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_[('x', 'y')]  4.65ms ± 2%             4.62ms ± 3%     ~            (p=0.440 n=5+10)
gda_construction_raw_(256, 8)_[('x', 'y')]     18.6ms ± 3%              9.7ms ± 5%  -47.76%          (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_[('x', 'y')]  11.8ms ± 4%              3.5ms ± 2%  -70.28%          (p=0.008 n=5+5)
gda_construction_raw_(128, 8)_['x', 'y']       8.54ms ± 1%             4.03ms ± 2%  -52.84%          (p=0.008 n=5+5)
indices_replica_id_calc_(128, 8)_['x', 'y']    5.40ms ± 4%             1.10ms ± 1%  -79.69%          (p=0.008 n=5+5)
gda_construction_raw_(4, 2)_['x', 'y']          173µs ± 1%              193µs ± 3%  +11.63%          (p=0.008 n=5+5)
indices_replica_id_calc_(4, 2)_['x', 'y']       127µs ± 1%              147µs ± 1%  +15.57%          (p=0.008 n=5+5)
```

PiperOrigin-RevId: 421623147
2022-01-13 11:53:13 -08:00
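The dataclass-to-function replacement described in the bullets above can be sketched in plain Python. The index format here (a tuple of slices, flattened to hashable (start, stop) pairs) is an assumption for illustration, not the exact GDA internals:

```python
from dataclasses import dataclass

# Old, slower approach: a frozen dataclass wrapper whose only job is to
# make the index hashable; constructing it costs time on every lookup.
@dataclass(frozen=True)
class HashableIndex:
    bounds: tuple  # e.g. ((start, stop), ...) per dimension

# New, faster approach from the commit: hash a plain tuple directly.
def hashed_index(index):
    # `index` is a tuple of slices; slices aren't hashable (before 3.12),
    # so flatten them to (start, stop) pairs first.
    return hash(tuple((s.start, s.stop) for s in index))

idx = (slice(0, 2), slice(4, 8))
# The function produces a stable cache key without any object allocation:
assert hashed_index(idx) == hashed_index((slice(0, 2), slice(4, 8)))
```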
Yash Katariya
fbb8b9f8c6 Benchmarks for GDA. Also move create_global_mesh to test_utils since it was replicated in a lot of places.
PiperOrigin-RevId: 421142813
2022-01-11 15:43:05 -08:00
Tianjian Lu
6a0bfdd637 [JAX] Requires indices to be sorted and of int32 in _sparse_bcoo_matvec test.
PiperOrigin-RevId: 417695937
2021-12-21 14:56:52 -08:00
Jake VanderPlas
b8296bb06c benchmarks: add compilation of sparse operations
PiperOrigin-RevId: 405752775
2021-10-26 15:41:21 -07:00
Jake VanderPlas
c62452f2d2 benchmarks: add JIT versions of sparse.BCOO benchmarks
PiperOrigin-RevId: 405696495
2021-10-26 11:39:01 -07:00
Jake VanderPlas
5a3733788c benchmarks: add sparse.BCOO todense/fromdense/matvec benchmarks
PiperOrigin-RevId: 405519824
2021-10-25 16:44:36 -07:00
Yash Katariya
bfbdfa87e7 Add a warmup loop to pmap_simple_8_devices_100_args benchmark so as to not measure the compile time.
PiperOrigin-RevId: 401402336
2021-10-06 19:51:35 -07:00
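The warmup pattern this commit adds is generic: run the function a few times before the timed loop so compilation and cache population happen outside the measurement. A stdlib-only sketch (in the real benchmark `f` is a pmapped JAX function; here a plain lambda stands in):

```python
import timeit

def benchmark(f, *args, warmup=3, iters=100):
    # Warmup calls absorb one-time costs (tracing, compilation, caches)
    # so the timed loop measures only steady-state execution.
    for _ in range(warmup):
        f(*args)
    total = timeit.timeit(lambda: f(*args), number=iters)
    return total / iters

avg = benchmark(lambda x: x * 2, 21)
assert avg > 0.0
```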
Peter Hawkins
256e7220ff [JAX] Fix pylint errors.
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...

PiperOrigin-RevId: 400858477
2021-10-04 17:54:46 -07:00
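The `dangerous-default-value` item in the list above is the classic mutable-default pitfall; a minimal sketch of the bug and the standard fix:

```python
# Buggy: the default list is created once at definition time and shared
# across every call, so state leaks between calls.
def append_bad(item, acc=[]):
    acc.append(item)
    return acc

# Fixed: use None as the sentinel and build a fresh list per call.
def append_good(item, acc=None):
    if acc is None:
        acc = []
    acc.append(item)
    return acc

assert append_bad(1) == [1]
assert append_bad(2) == [1, 2]     # surprise: previous call's state leaked
assert append_good(1) == [1]
assert append_good(2) == [2]       # fresh list each time
```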
Peter Hawkins
6a1b626564 Remove jax.api.
Functions exported as jax.api were aliases for names in jax.*. Use the jax.* names instead.
2021-09-16 16:29:06 -04:00
Jean-Baptiste Lespiau
6cb8737c1a Add a benchmark with many arguments.
PiperOrigin-RevId: 393216026
2021-08-26 15:06:09 -07:00
Jean-Baptiste Lespiau
afdd195e42 Internal only.
PiperOrigin-RevId: 388562206
2021-08-03 15:49:17 -07:00
Qiao Zhang
850bd66242 [JAX] Prune unused inputs in jit.
- Python part based on: https://github.com/google/jax/pull/6567
- Added cpp_jit path to handle pruned args

PiperOrigin-RevId: 371743277
2021-05-03 11:41:29 -07:00
Tom Hennigan
9d56552517 Remove special casing on npy_value when indexing sharded arrays.
Before:

```
GPU
sda_index_2                           2912972 ns      2778716 ns          256

TPU
sda_index_1                            769968 ns       751700 ns          921
sda_index_2                           1510841 ns      1489716 ns          465
sda_index_8                           6102259 ns      6027655 ns          117
```

After:

```
GPU
sda_index_2                             28095 ns        27983 ns        25463

TPU
sda_index_1                             10302 ns        10279 ns        67884
sda_index_2                             20010 ns        19947 ns        34628
sda_index_8                             78492 ns        78306 ns         8934
```

PiperOrigin-RevId: 368380864
2021-04-14 01:17:10 -07:00
Matthew Johnson
9802f3378e add simple single-primitive eager benchmarks 2021-03-18 21:46:46 -07:00
Peter Hawkins
cdd36b1113 Improve API benchmarks.
Add benchmarks for different dispatch arg arities.

Add more blocking before and after benchmark loops that don't otherwise block.
2021-03-03 20:50:45 -05:00
Tom Hennigan
7b94c44af9 Sort imports.
PiperOrigin-RevId: 359532656
2021-02-25 08:52:09 -08:00
Peter Hawkins
160dfd343a Revert import path changes to examples/ and benchmarks/
PiperOrigin-RevId: 352911869
2021-01-20 17:35:55 -08:00
Peter Hawkins
929a684a39 Small cleanups to dependency structure.
PiperOrigin-RevId: 352853244
2021-01-20 12:43:28 -08:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src.util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
Jake Vanderplas
6a89f60683
fix benchmark sums (#4329) 2020-09-18 09:24:00 -07:00
Jean-Baptiste Lespiau
e95d5701e3
Add benchmarks for specifically the dispatch time. (#4128)
The goal is to distinguish the time it takes for `jitted_f` to return from the time it takes to return and then wait for the result.
We also add benchmarks to distinguish the time it takes to call the function with the argument transfer from the time without it.

e.g.

name                                   time/op
jit_trivial_dispatch                   28.9µs ± 2%
jit_trivial                            31.5µs ± 5%
jit_simple_dispatch                    60.7µs ± 4%
jit_simple                              129µs ±24%
jit_simple_many_args_disptch            390µs ±19%
jit_simple_many_args                    388µs ±16%
jit_dispatch_without_transfer           379µs ± 6%
jit_dispatch_with_transfer              450µs ± 5%
2020-08-27 17:02:13 +03:00
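The dispatch-versus-result distinction these benchmarks measure comes from JAX's asynchronous dispatch: a jitted call returns as soon as the work is enqueued, and waiting for the value is a separate step. A stdlib-only sketch of the same two measurements, using a thread pool as a stand-in for the async runtime (in real JAX one would call `result.block_until_ready()` rather than `future.result()`):

```python
import time
from concurrent.futures import ThreadPoolExecutor

pool = ThreadPoolExecutor(max_workers=1)

def slow_op():
    time.sleep(0.05)   # stands in for actual device computation
    return 42

t0 = time.perf_counter()
fut = pool.submit(slow_op)                 # "dispatch": returns immediately
dispatch_time = time.perf_counter() - t0

result = fut.result()                      # wait for the actual value
total_time = time.perf_counter() - t0
pool.shutdown()

assert result == 42
assert dispatch_time < total_time          # dispatch is much cheaper than waiting
```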