Parker Schuh
48702171bf
Add benchmarks for np.array, device_put, and _arrays.
...
PiperOrigin-RevId: 516692492
2023-03-14 19:06:06 -07:00
Yash Katariya
233911c001
[Fix forward] Rollback the device_put_sharded and device_put_replicated change of using batched_device_put
...
PiperOrigin-RevId: 516244071
2023-03-13 10:07:44 -07:00
Peter Hawkins
1925aa1109
Split Sharding subclasses out of _src/sharding.py into _src/sharding_impls.py
...
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.
PiperOrigin-RevId: 516223009
2023-03-13 08:50:18 -07:00
Emilio Cota
6f1d82916c
math_benchmark: add --set_env flag
...
PiperOrigin-RevId: 515417422
2023-03-09 13:04:12 -08:00
Emilio Cota
845d68b39e
math_benchmark: add dot op
...
PiperOrigin-RevId: 515408666
2023-03-09 12:24:47 -08:00
Peter Hawkins
8fb1fd318d
Replace jax._src.util.prod with math.prod.
...
math.prod() was added in Python 3.8, so we can assume it is always present.
PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
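The replacement described in the commit above can be illustrated with a small sketch. The `prod_fallback` helper is a hypothetical stand-in for the kind of utility that `math.prod` makes unnecessary; it is not the actual `jax._src.util.prod` implementation.

```python
import math
import operator
from functools import reduce


def prod_fallback(xs):
    """Hypothetical hand-written product helper (an assumption, not JAX code)."""
    return reduce(operator.mul, xs, 1)


# A typical use in array code: the number of elements implied by a shape.
shape = (8, 128, 4)
num_elements = math.prod(shape)

# The standard-library function agrees with the hand-written helper.
assert num_elements == prod_fallback(shape)

# Like the helper, math.prod returns 1 for an empty iterable (a scalar shape).
empty = math.prod(())
```

Because `math.prod` ships with Python 3.8 and later, a project that requires at least that version can drop such a helper entirely.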
Lena Martens
4f48f94649
Update api_benchmark to not use any deprecated APIs.
...
PiperOrigin-RevId: 512941633
2023-02-28 08:33:26 -08:00
Yash Katariya
418c2f9d2a
Rename `in_axis_resources` and `out_axis_resources` with `in_shardings` and `out_shardings`. This is just a simple name replacement. It does not change any of the current pjit semantics and doesn't break any code.
...
This is a safe and trivial name replacement. It does not change any of the semantics. You can still pass in PartitionSpecs to in_shardings and out_shardings.
PiperOrigin-RevId: 510671300
2023-02-18 10:00:36 -08:00
Yash Katariya
d21ff0371f
Remove gda_benchmark file as GDA is deprecated.
...
PiperOrigin-RevId: 510469600
2023-02-17 10:46:25 -08:00
Peter Hawkins
88cc254f2c
[JAX] Replace uses of jax.interpreters.pxla.ShardedDeviceArray with jax.Array.
...
PiperOrigin-RevId: 508463147
2023-02-09 13:39:41 -08:00
Peter Hawkins
98b75cf27b
Prune accidental exports from jax.interpreters.pxla.
...
These imports do not appear to have users outside JAX itself.
PiperOrigin-RevId: 507835295
2023-02-07 11:16:42 -08:00
Peter Hawkins
428189f8fb
Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
...
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.
PiperOrigin-RevId: 506994892
2023-02-03 14:28:45 -08:00
Jake VanderPlas
43e57db77a
Begin deprecation of public jax.ShapedArray
2023-01-30 11:27:58 -08:00
Emilio Cota
13e875f8b8
benchmarks: add math unary benchmarks
...
These will be used for benchmarking FP approximations in XLA.
PiperOrigin-RevId: 503991586
2023-01-23 08:17:16 -08:00
jax authors
eb875cd5dd
Added a pattern-match optimisation for inplace-select.
...
PiperOrigin-RevId: 497425937
2022-12-23 16:05:56 -08:00
Peter Hawkins
d6c67c97db
Remove redundant dtype canonicalization from jax.device_put().
...
Gives a small improvement to the included jax.device_put() benchmark on my VM:
```
name        old cpu/op   new cpu/op   delta
device_put  91.3µs ± 5%  80.1µs ± 3%  -12.29%  (p=0.008 n=5+5)

name        old time/op  new time/op  delta
device_put  91.4µs ± 5%  80.1µs ± 3%  -12.29%  (p=0.008 n=5+5)
```
jax.device_put() has not been optimized that much yet and there is plenty of room for further improvement.
PiperOrigin-RevId: 491727173
2022-11-29 13:47:36 -08:00
Yash Katariya
928dee415f
Optimize `host_local_array_to_global_array` by caching the local to global conversion and flattening of axis resources. Also take a fast path for device_put which does not do `abstractify` and only `canonicalize_dtype` on the entire array once (instead of doing it for every shard).
...
This results in a roughly 4.5x speedup (3.03 ms to 0.673 ms)!
Before:
```
---------------------------------------------------------------------------
Benchmark                                 Time             CPU   Iterations
---------------------------------------------------------------------------
host_local_array_to_global_array       3.03 ms         3.02 ms          220
```
After:
```
---------------------------------------------------------------------------
Benchmark                                 Time             CPU   Iterations
---------------------------------------------------------------------------
host_local_array_to_global_array      0.673 ms        0.671 ms          985
```
PiperOrigin-RevId: 489880547
2022-11-20 20:53:02 -08:00
Yash Katariya
c42bad85ef
Make `MeshPspecSharding` an alias for `NamedSharding` (it was the other way around before this CL).
...
PiperOrigin-RevId: 488473538
2022-11-14 14:44:00 -08:00
jax authors
f4be5ab173
Merge pull request #12219 from jakevdp:indexing-slice
...
PiperOrigin-RevId: 485946084
2022-11-03 12:44:28 -07:00
Yash Katariya
532cd7ed74
Skip the benchmarks properly via state.skip_with_error when enough devices are not present.
...
PiperOrigin-RevId: 485931295
2022-11-03 11:44:57 -07:00
Jake VanderPlas
753562d574
Add benchmarks for repeated static indexing & slicing
2022-11-03 11:41:37 -07:00
Hyeontaek Lim
fc8f40ce0e
Internal visibility change
...
PiperOrigin-RevId: 484340424
2022-10-27 13:49:16 -07:00
Yash Katariya
cf6b5097d0
Remove pytest_benchmark for test-requirements.txt and move the benchmark file which was using that package to use google_benchmark.
...
PiperOrigin-RevId: 483736267
2022-10-25 11:59:32 -07:00
Yash Katariya
3572bb2db0
[Rollback]
...
Allow uncommitted single device PyArray in C++ pjit path.
PiperOrigin-RevId: 482084898
2022-10-18 19:42:10 -07:00
Kuangyuan Chen
d64da3d407
Roll forward with fix: Remove the original Python function `fun_` from C++ PjitFunction, as destroying `fun_` may yield the thread in some cases, which causes an error while deleting the Python object of PjitFunction.
...
PiperOrigin-RevId: 481950912
2022-10-18 10:05:53 -07:00
Kuangyuan Chen
fd2f590b3b
Allow uncommitted single device PyArray in C++ pjit path.
...
PiperOrigin-RevId: 481711690
2022-10-17 12:35:30 -07:00
jax authors
504b3c1b25
roll forward with the fix: Make `params` arg in Compiled.call() position-only so that it does not conflict with the keyword args.
...
PiperOrigin-RevId: 481666211
2022-10-17 09:50:55 -07:00
Kuangyuan Chen
38a7582923
roll forward with the fix: Make `params` arg in Compiled.call() position-only so that it does not conflict with the keyword args.
...
PiperOrigin-RevId: 481181330
2022-10-14 10:42:15 -07:00
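The position-only fix described in the entries above can be sketched in plain Python. The `Compiled` class below is a hypothetical stand-in written for illustration, not JAX's actual class; it only shows how the `/` marker lets a caller pass a keyword argument that happens to be named `params`.

```python
class Compiled:
    """Hypothetical stand-in for a compiled-function wrapper (not JAX's class)."""

    def call(self, params, /, *args, **kwargs):
        # `params` is position-only, so a keyword argument named "params"
        # lands in **kwargs instead of clashing with the first parameter.
        return params, args, kwargs


compiled = Compiled()

# The internal params are passed positionally; the user's own keyword
# argument called "params" is forwarded untouched.
result = compiled.call({"internal": True}, 1, 2, params="user value")
```

Without the `/`, the same call would raise `TypeError` because `params` would receive multiple values.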
jax authors
1945208d34
Rollback because of failing tests internally.
...
PiperOrigin-RevId: 481103002
2022-10-14 03:12:42 -07:00
Kuangyuan Chen
d082ea0d46
Implement a fast path for pjit AOT in C++ for jax.Array inputs.
...
PiperOrigin-RevId: 480983807
2022-10-13 14:24:05 -07:00
Yash Katariya
84768d2d49
Replace `jax.xla.DeviceArray` private type with the new public type `jax.Array`.
...
PiperOrigin-RevId: 477582562
2022-09-28 16:34:10 -07:00
Yash Katariya
9e4114f0f1
Move `array.py` and `sharding.py` from `experimental/` to `_src/`.
...
PiperOrigin-RevId: 477201711
2022-09-27 10:06:52 -07:00
Peter Hawkins
ba557d5e1b
Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
...
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.
PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Kuangyuan Chen
405a2310ce
Implement pjit fast path in cpp for jax.Array inputs
...
PiperOrigin-RevId: 475988677
2022-09-21 20:18:18 -07:00
Kuangyuan Chen
b9e384363e
Add many args benchmark for jax.Array
...
PiperOrigin-RevId: 475853211
2022-09-21 09:54:51 -07:00
Kuangyuan Chen
b764aadbcf
Add pmap benchmarks using jax.Array
...
PiperOrigin-RevId: 473325120
2022-09-09 13:10:42 -07:00
Kuangyuan Chen
d17e516ea7
Add benchmarks for jax.Array
...
PiperOrigin-RevId: 471889808
2022-09-02 14:45:09 -07:00
Luca Di Grazia
88f6051233
`target_total_secs` has type `int` but used as type `None`
...
"filename": "benchmarks/benchmark.py"
"warning_type": "Incompatible variable type [9]",
"warning_message": " target_total_secs is declared to have type `int` but is used as type `None`."
"warning_line": 86
"fix": int to Optional[int]
2022-08-22 15:46:38 +02:00
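The `int` to `Optional[int]` fix reported in the entry above follows a common pattern, sketched below. The function name `resolve_target_secs` and the default of 10 seconds are assumptions made for illustration; they are not taken from `benchmarks/benchmark.py`.

```python
from typing import Optional


def resolve_target_secs(target_total_secs: Optional[int] = None,
                        default_secs: int = 10) -> int:
    """Return the requested time budget, or a default when none is given.

    Annotating the parameter as Optional[int] (rather than int) tells the
    type checker that None is a legitimate value. Both the function and
    the default are hypothetical.
    """
    if target_total_secs is None:
        return default_secs
    return target_total_secs


unset = resolve_target_secs()
explicit = resolve_target_secs(3)
```

With the annotation corrected, a type checker no longer reports the "declared to have type `int` but is used as type `None`" warning for a `None` default.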
Matthew Johnson
b3b4ffbc21
make slicing compilation a little faster
...
For the common special case of int or slice-of-int/None indexing, we can
generate a lax.slice rather than a lax.gather. That makes compilation a little
faster, and makes the generated jaxpr a bit more wieldy too, to process in
transformations and to read when pretty-printed.
2022-08-19 09:14:37 -07:00
Matthew Johnson
42dd7cac43
simplify slicing jaxprs a little
...
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-08-12 19:52:02 -07:00
Matthew Johnson
e3a92d52ba
prepare to switch to new remat
...
This commit involves a few things, which are all united in being about landing
the new remat (aka new checkpoint) implementation:
* add benchmarks for new remat eager performance, and some caching to make those
benchmarks fast
* warn when the old-remat-exclusive `concrete` feature is used, with an
actionable message pointing to the new recommended approach involving static_argnums
* add the static_argnums parameter to both new and old remat
* update docstrings (and de-duplicate them too)
* add new tests, especially around caching and errors/warnings
2022-08-04 12:25:03 -07:00
Yash Katariya
2109c6ec8c
Make `is_compatible_aval` an optional method which sharding interfaces can implement to raise a more meaningful error. Otherwise lower to opsharding and catch the error if it fails.
...
PiperOrigin-RevId: 464147877
2022-07-29 13:37:15 -07:00
Yash Katariya
47623264db
Export `HloSharding` via pybind which is a C++ wrapper around OpSharding proto.
...
PiperOrigin-RevId: 463992136
2022-07-28 21:01:15 -07:00
Yash Katariya
f4637c364d
Fix the `gda_xla_sharding_match` benchmark, which was regressing. This happened because the function was executed from top to bottom a couple of times, and each time a new mesh object was created, bypassing the already populated cache, which doesn't happen in real life.
...
```
gda_xla_sharding_match_(256, 8)_PartitionSpec('x', 'y')     21.8ms ± 2%  1.3ms ± 2%   -93.80%  (p=0.008 n=5+5)
gda_xla_sharding_match_(256, 8)_PartitionSpec(None,)        21.8ms ± 4%  1.3ms ± 1%   -93.92%  (p=0.008 n=5+5)
gda_xla_sharding_match_(256, 8)_PartitionSpec('x',)         21.8ms ± 3%  1.3ms ± 1%   -94.11%  (p=0.008 n=5+5)
gda_xla_sharding_match_(256, 8)_PartitionSpec('y',)         21.8ms ± 3%  1.3ms ± 0%   -94.12%  (p=0.008 n=5+5)
gda_xla_sharding_match_(256, 8)_PartitionSpec(('x', 'y'),)  21.8ms ± 3%  1.3ms ± 1%   -94.07%  (p=0.008 n=5+5)
gda_xla_sharding_match_(128, 8)_PartitionSpec('x', 'y')     13.9ms ± 6%  1.3ms ± 1%   -90.85%  (p=0.008 n=5+5)
gda_xla_sharding_match_(4, 2)_PartitionSpec('x', 'y')       5.72ms ±10%  1.25ms ± 1%  -78.15%  (p=0.008 n=5+5)
gda_xla_sharding_match_(16, 4)_PartitionSpec('x', 'y')      6.17ms ±11%  1.25ms ± 1%  -79.71%  (p=0.008 n=5+5)
gda_xla_sharding_match_(16, 4)_PartitionSpec(('x', 'y'),)   6.17ms ±10%  1.26ms ± 2%  -79.61%  (p=0.008 n=5+5)
```
PiperOrigin-RevId: 463760534
2022-07-27 23:08:55 -07:00
Matthew Johnson
148173630f
add an optional fastpath for api_util.shaped_abstractify
...
also add a benchmark for it, 8.7ms -> 0.2ms on my machine
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2022-07-27 15:14:37 -07:00
Yash Katariya
d8cbb29d14
OpSharding doesn't have `__eq__` defined on it. Don't check sharding equality using opsharding until it does support that.
...
PiperOrigin-RevId: 462238497
2022-07-20 15:03:39 -07:00
Jake VanderPlas
a10f0377db
Avoid top-level aliases of jax.tree_util.*
2022-07-07 11:41:02 -07:00
Jeppe Klitgaard
17de89b16a
feat: refactor code using pyupgrade
...
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```
2022-05-17 22:14:05 +01:00
Yash Katariya
6a7a34603d
Move PartitionSpec from sharded_jit.py to pxla.py. The public endpoint is via jax.experimental so that should be used (no changes to the public endpoint).
...
This move is because sharded_jit is being deprecated.
PiperOrigin-RevId: 439948391
2022-04-06 15:19:19 -07:00
Yash Katariya
e08bc27bf0
Speed up GDA initialization by making local_shards lazy and hiding checks behind `config.jax_enable_checks` flag.
...
PiperOrigin-RevId: 438115859
2022-03-29 13:51:11 -07:00