Peter Hawkins
98b75cf27b
Prune accidental exports from jax.interpreters.pxla.
...
These imports do not appear to have users outside JAX itself.
PiperOrigin-RevId: 507835295
2023-02-07 11:16:42 -08:00
Peter Hawkins
428189f8fb
Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
...
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.
PiperOrigin-RevId: 506994892
2023-02-03 14:28:45 -08:00
Jake VanderPlas
43e57db77a
Begin deprecation of public jax.ShapedArray
2023-01-30 11:27:58 -08:00
Emilio Cota
13e875f8b8
benchmarks: add math unary benchmarks
...
These will be used for benchmarking FP approximations in XLA.
PiperOrigin-RevId: 503991586
2023-01-23 08:17:16 -08:00
jax authors
eb875cd5dd
Added a pattern-match optimisation for inplace-select.
...
PiperOrigin-RevId: 497425937
2022-12-23 16:05:56 -08:00
Peter Hawkins
d6c67c97db
Remove redundant dtype canonicalization from jax.device_put().
...
Gives a small improvement to the included jax.device_put() benchmark on my VM:
```
name old cpu/op new cpu/op delta
device_put 91.3µs ± 5% 80.1µs ± 3% -12.29% (p=0.008 n=5+5)
name old time/op new time/op delta
device_put 91.4µs ± 5% 80.1µs ± 3% -12.29% (p=0.008 n=5+5)
```
jax.device_put() has not been optimized that much yet and there is plenty of room for further improvement.
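For context, a minimal sketch of what dtype canonicalization does (an assumed mapping mirroring JAX's x64-mode behavior, not the internal implementation): with 64-bit mode disabled, wide dtypes are narrowed to their 32-bit counterparts, so doing this once per array rather than again per shard removes redundant work on the device_put path.

```python
import numpy as np

# Hypothetical canonicalization table, mirroring JAX's behavior when
# 64-bit mode is off: wide dtypes map to their 32-bit counterparts.
_CANONICAL = {
    np.dtype('float64'): np.dtype('float32'),
    np.dtype('int64'): np.dtype('int32'),
    np.dtype('complex128'): np.dtype('complex64'),
}

def canonicalize_dtype(dtype, x64_enabled=False):
    # Narrow the dtype once; repeating this per shard is redundant.
    dtype = np.dtype(dtype)
    return dtype if x64_enabled else _CANONICAL.get(dtype, dtype)
```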
PiperOrigin-RevId: 491727173
2022-11-29 13:47:36 -08:00
Yash Katariya
928dee415f
Optimize host_local_array_to_global_array
by caching the local-to-global conversion and the flattening of axis resources. Also take a fast path for device_put that skips abstractify
and runs canonicalize_dtype on the entire array only once (instead of doing it for every shard).
...
This results in a 5x speedup!
Before:
```
---------------------------------------------------------------------------
Benchmark Time CPU Iterations
---------------------------------------------------------------------------
host_local_array_to_global_array 3.03 ms 3.02 ms 220
```
After:
```
---------------------------------------------------------------------------
Benchmark Time CPU Iterations
---------------------------------------------------------------------------
host_local_array_to_global_array 0.673 ms 0.671 ms 985
```
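The caching part of the change can be sketched as follows (a toy helper, assuming the flattening depends only on hashable axis resources; not the real JAX code): the result is memoized, so every array in a pytree that shares the same axis resources reuses one computation.

```python
from functools import lru_cache

calls = 0  # counts how often the expensive work actually runs

@lru_cache(maxsize=None)
def flatten_axis_resources(axis_resources):
    global calls
    calls += 1
    # Stand-in for the expensive PartitionSpec flattening.
    return tuple(
        name
        for entry in axis_resources
        for name in (entry if isinstance(entry, tuple) else (entry,))
    )

flatten_axis_resources((('x',), 'y'))
flatten_axis_resources((('x',), 'y'))  # served from the cache
```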
PiperOrigin-RevId: 489880547
2022-11-20 20:53:02 -08:00
Yash Katariya
c42bad85ef
Make MeshPspecSharding an alias for NamedSharding (it was the other way around before this CL).
...
PiperOrigin-RevId: 488473538
2022-11-14 14:44:00 -08:00
jax authors
f4be5ab173
Merge pull request #12219 from jakevdp:indexing-slice
...
PiperOrigin-RevId: 485946084
2022-11-03 12:44:28 -07:00
Yash Katariya
532cd7ed74
Skip the benchmarks properly via state.skip_with_error when not enough devices are present.
...
PiperOrigin-RevId: 485931295
2022-11-03 11:44:57 -07:00
Jake VanderPlas
753562d574
Add benchmarks for repeated static indexing & slicing
2022-11-03 11:41:37 -07:00
Hyeontaek Lim
fc8f40ce0e
Internal visibility change
...
PiperOrigin-RevId: 484340424
2022-10-27 13:49:16 -07:00
Yash Katariya
cf6b5097d0
Remove pytest_benchmark from test-requirements.txt and port the benchmark file that was using that package to google_benchmark.
...
PiperOrigin-RevId: 483736267
2022-10-25 11:59:32 -07:00
Yash Katariya
3572bb2db0
[Rollback]
...
Allow uncommitted single device PyArray in C++ pjit path.
PiperOrigin-RevId: 482084898
2022-10-18 19:42:10 -07:00
Kuangyuan Chen
d64da3d407
Roll forward with fix: Remove the original Python function fun_
from C++ PjitFunction, as destroying fun_
may yield the thread in some cases, which causes an error when deleting the Python object of PjitFunction.
...
PiperOrigin-RevId: 481950912
2022-10-18 10:05:53 -07:00
Kuangyuan Chen
fd2f590b3b
Allow uncommitted single device PyArray in C++ pjit path.
...
PiperOrigin-RevId: 481711690
2022-10-17 12:35:30 -07:00
jax authors
504b3c1b25
Roll forward with the fix: Make the params
arg in Compiled.call() positional-only so that it does not conflict with keyword args.
...
PiperOrigin-RevId: 481666211
2022-10-17 09:50:55 -07:00
Kuangyuan Chen
38a7582923
Roll forward with the fix: Make the params
arg in Compiled.call() positional-only so that it does not conflict with keyword args.
...
PiperOrigin-RevId: 481181330
2022-10-14 10:42:15 -07:00
jax authors
1945208d34
Rollback because of failing tests internally.
...
PiperOrigin-RevId: 481103002
2022-10-14 03:12:42 -07:00
Kuangyuan Chen
d082ea0d46
Implement a fast path for pjit AOT in C++ for jax.Array inputs.
...
PiperOrigin-RevId: 480983807
2022-10-13 14:24:05 -07:00
Yash Katariya
84768d2d49
Replace the private type jax.xla.DeviceArray with the new public type jax.Array.
...
PiperOrigin-RevId: 477582562
2022-09-28 16:34:10 -07:00
Yash Katariya
9e4114f0f1
Move array.py and sharding.py from experimental/ to _src/.
...
PiperOrigin-RevId: 477201711
2022-09-27 10:06:52 -07:00
Peter Hawkins
ba557d5e1b
Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
...
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.
PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Kuangyuan Chen
405a2310ce
Implement pjit fast path in cpp for jax.Array inputs
...
PiperOrigin-RevId: 475988677
2022-09-21 20:18:18 -07:00
Kuangyuan Chen
b9e384363e
Add many args benchmark for jax.Array
...
PiperOrigin-RevId: 475853211
2022-09-21 09:54:51 -07:00
Kuangyuan Chen
b764aadbcf
Add pmap benchmarks using jax.Array
...
PiperOrigin-RevId: 473325120
2022-09-09 13:10:42 -07:00
Kuangyuan Chen
d17e516ea7
Add benchmarks for jax.Array
...
PiperOrigin-RevId: 471889808
2022-09-02 14:45:09 -07:00
Luca Di Grazia
88f6051233
target_total_secs has type int but is used as type None
...
"filename": "benchmarks/benchmark.py",
"warning_type": "Incompatible variable type [9]",
"warning_message": "target_total_secs is declared to have type `int` but is used as type `None`.",
"warning_line": 86,
"fix": change `int` to `Optional[int]`
2022-08-22 15:46:38 +02:00
Matthew Johnson
b3b4ffbc21
make slicing compilation a little faster
...
For the common special case of int or slice-of-int/None indexing, we can
generate a lax.slice rather than a lax.gather. That makes compilation a little
faster, and makes the generated jaxpr a bit more wieldy too, to process in
transformations and to read when pretty-printed.
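The dispatch test can be sketched like this (a hypothetical helper, not the actual JAX implementation): an index built only from ints, None, and slices with static int/None bounds can lower to a single lax.slice, while anything else (arrays, boolean masks) needs the general lax.gather path.

```python
def is_simple_static_index(idx) -> bool:
    # True when every index element is an int, None, or a slice whose
    # start/stop/step are static ints or None; such indices can lower
    # to one lax.slice instead of a lax.gather.
    elems = idx if isinstance(idx, tuple) else (idx,)

    def ok(e):
        if isinstance(e, int) or e is None:
            return True
        if isinstance(e, slice):
            return all(isinstance(p, int) or p is None
                       for p in (e.start, e.stop, e.step))
        return False

    return all(ok(e) for e in elems)
```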
2022-08-19 09:14:37 -07:00
Matthew Johnson
42dd7cac43
simplify slicing jaxprs a little
...
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-08-12 19:52:02 -07:00
Matthew Johnson
e3a92d52ba
prepare to switch to new remat
...
This commit involves a few things, which are all united in being about landing
the new remat (aka new checkpoint) implementation:
* add benchmarks for new remat eager performance, and some caching to make those
benchmarks fast
* warn when the old-remat-exclusive `concrete` feature is used, with an
actionable message pointing to the new recommended approach involving static_argnums
* add the static_argnums parameter to both new and old remat
* update docstrings (and de-duplicate them too)
* add new tests, especially around caching and errors/warnings
2022-08-04 12:25:03 -07:00
Yash Katariya
2109c6ec8c
Make is_compatible_aval an optional method that sharding interfaces can implement to raise a more meaningful error. Otherwise, lower to OpSharding and catch the error if that fails.
...
PiperOrigin-RevId: 464147877
2022-07-29 13:37:15 -07:00
Yash Katariya
47623264db
Export HloSharding, a C++ wrapper around the OpSharding proto, via pybind.
...
PiperOrigin-RevId: 463992136
2022-07-28 21:01:15 -07:00
Yash Katariya
f4637c364d
Fix the gda_xla_sharding_match
benchmark, which was regressing. This was happening because the function was executed from top to bottom a couple of times, and each time a new mesh object was created, invalidating the already-populated cache; this doesn't happen in real life.
...
```
gda_xla_sharding_match_(256, 8)_PartitionSpec('x', 'y') 21.8ms ± 2% 1.3ms ± 2% -93.80% (p=0.008 n=5+5)
gda_xla_sharding_match_(256, 8)_PartitionSpec(None,) 21.8ms ± 4% 1.3ms ± 1% -93.92% (p=0.008 n=5+5)
gda_xla_sharding_match_(256, 8)_PartitionSpec('x',) 21.8ms ± 3% 1.3ms ± 1% -94.11% (p=0.008 n=5+5)
gda_xla_sharding_match_(256, 8)_PartitionSpec('y',) 21.8ms ± 3% 1.3ms ± 0% -94.12% (p=0.008 n=5+5)
gda_xla_sharding_match_(256, 8)_PartitionSpec(('x', 'y'),) 21.8ms ± 3% 1.3ms ± 1% -94.07% (p=0.008 n=5+5)
gda_xla_sharding_match_(128, 8)_PartitionSpec('x', 'y') 13.9ms ± 6% 1.3ms ± 1% -90.85% (p=0.008 n=5+5)
gda_xla_sharding_match_(4, 2)_PartitionSpec('x', 'y') 5.72ms ±10% 1.25ms ± 1% -78.15% (p=0.008 n=5+5)
gda_xla_sharding_match_(16, 4)_PartitionSpec('x', 'y') 6.17ms ±11% 1.25ms ± 1% -79.71% (p=0.008 n=5+5)
gda_xla_sharding_match_(16, 4)_PartitionSpec(('x', 'y'),) 6.17ms ±10% 1.26ms ± 2% -79.61% (p=0.008 n=5+5)
```
PiperOrigin-RevId: 463760534
2022-07-27 23:08:55 -07:00
Matthew Johnson
148173630f
add an optional fastpath for api_util.shaped_abstractify
...
also add a benchmark for it, 8.7ms -> 0.2ms on my machine
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2022-07-27 15:14:37 -07:00
Yash Katariya
d8cbb29d14
OpSharding doesn't have __eq__
defined on it. Don't check sharding equality using OpSharding until it supports that.
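The pitfall can be shown with a toy class (hypothetical, not the real proto): without __eq__, Python falls back to identity comparison, so two structurally identical objects compare unequal and any equality-based sharding check is meaningless.

```python
class OpShardingLike:
    # Toy stand-in for a proto without __eq__: the default comparison
    # is by object identity, not by field values.
    def __init__(self, tile_assignment_dimensions):
        self.tile_assignment_dimensions = tile_assignment_dimensions

a = OpShardingLike((2, 4))
b = OpShardingLike((2, 4))
# a != b even though the fields match, because comparison is identity.
```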
...
PiperOrigin-RevId: 462238497
2022-07-20 15:03:39 -07:00
Jake VanderPlas
a10f0377db
Avoid top-level aliases of jax.tree_util.*
2022-07-07 11:41:02 -07:00
Jeppe Klitgaard
17de89b16a
feat: refactor code using pyupgrade
...
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```
2022-05-17 22:14:05 +01:00
Yash Katariya
6a7a34603d
Move PartitionSpec from sharded_jit.py to pxla.py. The public endpoint is via jax.experimental, so that should be used (the public endpoint is unchanged).
...
This move is because sharded_jit is being deprecated.
PiperOrigin-RevId: 439948391
2022-04-06 15:19:19 -07:00
Yash Katariya
e08bc27bf0
Speed up GDA initialization by making local_shards lazy and hiding checks behind the config.jax_enable_checks
flag.
...
PiperOrigin-RevId: 438115859
2022-03-29 13:51:11 -07:00
Yash Katariya
a68b0f3a0a
Fix the mesh_axes of benchmarks to be Pspecs
...
PiperOrigin-RevId: 437932954
2022-03-28 21:51:01 -07:00
Qiao Zhang
5d7f639769
Add small and big matmul to api_benchmarks.
...
```
name              cpu/op
jit_small_matmul  2.96µs ± 2%
jit_big_matmul    22.1µs ±21%

name              time/op
jit_small_matmul  2.96µs ± 2%
jit_big_matmul    22.7µs ±21%
```
PiperOrigin-RevId: 435453853
2022-03-17 14:51:17 -07:00
Yash Katariya
4e47de66fc
Add the cache back now that Mesh's __hash__ also incorporates self.devices.shape.
...
PiperOrigin-RevId: 425711067
2022-02-01 14:06:01 -08:00
Yash Katariya
f80887e69c
Couple of changes because of the serialization inconsistencies being observed.
...
* Remove the cache, since one of the keys is global_mesh. The hash of global_mesh doesn't account for the mesh topology, only the devices. This is a problem because indices are assigned to devices based on the topology: if the mesh is `(4, 2)` and you then supply a new mesh `(2, 4)`, the cache will return results for `(4, 2)` since the devices haven't changed, even though the indices assigned under mesh `(2, 4)` differ from those under `(4, 2)`.
```
mesh_devices = np.array(jax.devices()).reshape((4, 2))
mesh_axes = ('x' , 'y')
global_mesh1 = Mesh(mesh_devices, mesh_axes)
print(global_mesh1)
mesh_devices = np.array(jax.devices()).reshape((2, 4))
mesh_axes = ('x' , 'y')
global_mesh2 = Mesh(mesh_devices, mesh_axes)
print(global_mesh2)
hash(global_mesh1) == hash(global_mesh2)
Output:
Mesh(array([[0, 1],
            [2, 3],
            [4, 5],
            [6, 7]]), ('x', 'y'))
Mesh(array([[0, 1, 2, 3],
            [4, 5, 6, 7]]), ('x', 'y'))
True
```
* Remove the replica_id calculation math, because after that the serialization inconsistencies were resolved. It is still confusing to me why this happens, since I can't reproduce it in unit tests or in small model runs; I'll debug it in parallel. The important thing here is to unblock everyone. Replacing it with _hashed_index is still 2x faster than using the _HashableIndex class.
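The collision described above can be reproduced with a toy mesh (a hedged sketch; the ToyMesh class is hypothetical, not JAX's Mesh): hashing only the devices makes a (4, 2) and a (2, 4) mesh over the same 8 devices collide, while also hashing the topology (the device array's shape) distinguishes them.

```python
import numpy as np

class ToyMesh:
    # Hypothetical mesh: __hash__ optionally includes the topology.
    def __init__(self, devices, axis_names, hash_topology):
        self.devices = np.asarray(devices)
        self.axis_names = tuple(axis_names)
        self._hash_topology = hash_topology

    def __hash__(self):
        key = (tuple(self.devices.flat), self.axis_names)
        if self._hash_topology:
            key += (self.devices.shape,)  # include the mesh topology
        return hash(key)

devs = list(range(8))
# Devices-only hash: the two topologies collide.
m1 = ToyMesh(np.reshape(devs, (4, 2)), ('x', 'y'), hash_topology=False)
m2 = ToyMesh(np.reshape(devs, (2, 4)), ('x', 'y'), hash_topology=False)
# Hash including devices.shape: the meshes are distinguished.
t1 = ToyMesh(np.reshape(devs, (4, 2)), ('x', 'y'), hash_topology=True)
t2 = ToyMesh(np.reshape(devs, (2, 4)), ('x', 'y'), hash_topology=True)
```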
PiperOrigin-RevId: 425525653
2022-01-31 20:48:17 -08:00
Yash Katariya
7f192c1946
Cache the expensive computations in GDA. For example, get_shard_indices_replica_ids
can be the same for multiple variables in a neural network, since (global_shape, mesh_axes, global_mesh) can be the same.
...
Note that the first time will be a little slow. The timings below show the caching working, because the benchmark runs for multiple iterations and the time is averaged over the number of iterations.
```
name time/op
gda_construction_callback_(4, 2)_['x', 'y'] 4.50ms ±10%
gda_construction_raw_(256, 8)_['x', 'y'] 5.82ms ± 2%
indices_replica_id_calc__uncached_(256, 8)_['x', 'y'] 2.95ms ± 6%
indices_replica_id_calc_cached_(256, 8)_['x', 'y'] 28.7µs ± 1%
gda_construction_callback_(4, 2)_[None] 31.9ms ±20%
gda_construction_raw_(256, 8)_[None] 5.85ms ± 5%
indices_replica_id_calc__uncached_(256, 8)_[None] 1.75ms ± 1%
indices_replica_id_calc_cached_(256, 8)_[None] 29.0µs ± 4%
gda_construction_callback_(4, 2)_['x'] 8.40ms ± 4%
gda_construction_raw_(256, 8)_['x'] 5.48ms ± 2%
indices_replica_id_calc__uncached_(256, 8)_['x'] 1.89ms ± 1%
indices_replica_id_calc_cached_(256, 8)_['x'] 29.0µs ± 4%
gda_construction_callback_(4, 2)_['y'] 15.3ms ± 6%
gda_construction_raw_(256, 8)_['y'] 5.66ms ± 5%
indices_replica_id_calc__uncached_(256, 8)_['y'] 1.82ms ± 2%
indices_replica_id_calc_cached_(256, 8)_['y'] 29.4µs ± 3%
gda_construction_callback_(4, 2)_[('x', 'y')] 4.29ms ± 5%
gda_construction_raw_(256, 8)_[('x', 'y')] 5.61ms ± 7%
indices_replica_id_calc__uncached_(256, 8)_[('x', 'y')] 3.81ms ±10%
indices_replica_id_calc_cached_(256, 8)_[('x', 'y')] 29.0µs ± 5%
gda_construction_raw_(128, 8)_['x', 'y'] 2.42ms ± 1%
indices_replica_id_calc__uncached_(128, 8)_['x', 'y'] 1.14ms ±11%
indices_replica_id_calc_cached_(128, 8)_['x', 'y'] 19.9µs ± 1%
gda_construction_raw_(4, 2)_['x', 'y'] 46.7µs ± 0%
indices_replica_id_calc__uncached_(4, 2)_['x', 'y'] 153µs ± 4%
indices_replica_id_calc_cached_(4, 2)_['x', 'y'] 11.1µs ± 8%
gda_construction_raw_(16, 4)_['x', 'y'] 164µs ± 2%
indices_replica_id_calc__uncached_(16, 4)_['x', 'y'] 212µs ± 3%
indices_replica_id_calc_cached_(16, 4)_['x', 'y'] 11.3µs ± 1%
gda_construction_raw_(16, 4)_[('x', 'y')] 163µs ± 2%
indices_replica_id_calc__uncached_(16, 4)_[('x', 'y')] 210µs ± 2%
indices_replica_id_calc_cached_(16, 4)_[('x', 'y')] 11.6µs ± 8%
```
PiperOrigin-RevId: 422639127
2022-01-18 13:53:52 -08:00
Yash Katariya
0532a63261
Optimizations for GDA to make creating GDA faster.
...
* Use math to figure out the replica id. Using `_hashed_index` (note that this is a function, not the `_HashableIndex` class, which no longer exists) is 1.5-2 times slower than using math. markdaoust@ helped with the math here (going to the office has its own perks :) )
* Get rid of the `_HashableIndex` class and replace it with a function `_hashed_index`. Dataclasses are extremely slow.
* Only calculate global_mesh.local_devices once, even though it's a cached property (cached_property only exists from Python 3.8 onward).
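A hedged reconstruction of the `_hashed_index` idea (an assumed shape, not the exact JAX code): flattening an index of slices into a tuple of ints and hashing that is far cheaper than instantiating a wrapper (data)class per index just to make it hashable.

```python
def hashed_index(index):
    # index: a tuple of slice objects with integer start/stop.
    # Flatten to a tuple of ints and hash it directly, avoiding the
    # cost of constructing a dataclass instance per index.
    return hash(tuple(v for s in index for v in (s.start, s.stop)))
```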
```
name old time/op new time/op delta
gda_construction_callback_(4, 2)_['x', 'y'] 4.77ms ± 5% 4.74ms ± 5% ~ (p=0.316 n=14+17)
gda_construction_raw_(256, 8)_['x', 'y'] 17.9ms ± 5% 9.0ms ± 2% -49.92% (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_['x', 'y'] 11.4ms ± 2% 2.9ms ± 2% -74.52% (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_[None] 34.0ms ±20% 30.5ms ± 2% ~ (p=0.413 n=5+4)
gda_construction_raw_(256, 8)_[None] 15.9ms ± 2% 7.7ms ± 3% -51.56% (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_[None] 9.39ms ± 3% 1.74ms ± 2% -81.44% (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_['x'] 8.87ms ± 2% 8.92ms ± 3% ~ (p=0.841 n=5+5)
gda_construction_raw_(256, 8)_['x'] 16.4ms ± 2% 7.7ms ± 1% -52.66% (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_['x'] 9.85ms ± 1% 1.90ms ± 2% -80.68% (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_['y'] 15.9ms ± 3% 16.0ms ± 5% ~ (p=0.690 n=5+5)
gda_construction_raw_(256, 8)_['y'] 15.8ms ± 3% 7.6ms ± 1% -52.04% (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_['y'] 9.29ms ± 1% 1.78ms ± 1% -80.79% (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_[('x', 'y')] 4.65ms ± 2% 4.62ms ± 3% ~ (p=0.440 n=5+10)
gda_construction_raw_(256, 8)_[('x', 'y')] 18.6ms ± 3% 9.7ms ± 5% -47.76% (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_[('x', 'y')] 11.8ms ± 4% 3.5ms ± 2% -70.28% (p=0.008 n=5+5)
gda_construction_raw_(128, 8)_['x', 'y'] 8.54ms ± 1% 4.03ms ± 2% -52.84% (p=0.008 n=5+5)
indices_replica_id_calc_(128, 8)_['x', 'y'] 5.40ms ± 4% 1.10ms ± 1% -79.69% (p=0.008 n=5+5)
gda_construction_raw_(4, 2)_['x', 'y'] 173µs ± 1% 193µs ± 3% +11.63% (p=0.008 n=5+5)
indices_replica_id_calc_(4, 2)_['x', 'y'] 127µs ± 1% 147µs ± 1% +15.57% (p=0.008 n=5+5)
```
PiperOrigin-RevId: 421623147
2022-01-13 11:53:13 -08:00
Yash Katariya
fbb8b9f8c6
Benchmarks for GDA. Also move create_global_mesh
to test_utils since it was replicated in a lot of places.
...
PiperOrigin-RevId: 421142813
2022-01-11 15:43:05 -08:00
Tianjian Lu
6a0bfdd637
[JAX] Require indices to be sorted and of type int32 in the _sparse_bcoo_matvec
test.
...
PiperOrigin-RevId: 417695937
2021-12-21 14:56:52 -08:00
Jake VanderPlas
b8296bb06c
benchmarks: add compilation of sparse operations
...
PiperOrigin-RevId: 405752775
2021-10-26 15:41:21 -07:00
Jake VanderPlas
c62452f2d2
benchmarks: add JIT versions of sparse.BCOO benchmarks
...
PiperOrigin-RevId: 405696495
2021-10-26 11:39:01 -07:00