Yash Katariya
a6254c75e0
Improve the shape incompatible error message by adding the argument/result name path to it.
...
PiperOrigin-RevId: 529605855
2023-05-04 21:50:04 -07:00
Jake VanderPlas
fbe4f10403
Change to simpler import for jax.config
2023-04-21 11:51:22 -07:00
Jake VanderPlas
5521423d92
Change np.prod->math.prod
...
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
2023-04-13 11:48:11 -07:00
Peter Hawkins
74384e6a87
Add a C++ safe_zip implementation.
...
Benchmark results on my workstation:
```
name old cpu/op new cpu/op delta
safe_zip/arg_lengths:0/num_args:1 1.22µs ± 1% 0.28µs ± 8% -77.33% (p=0.008 n=5+5)
safe_zip/arg_lengths:1/num_args:1 1.28µs ± 1% 0.34µs ± 6% -73.18% (p=0.008 n=5+5)
safe_zip/arg_lengths:2/num_args:1 1.28µs ± 1% 0.38µs ± 5% -70.26% (p=0.008 n=5+5)
safe_zip/arg_lengths:5/num_args:1 1.38µs ± 1% 0.51µs ± 3% -63.26% (p=0.008 n=5+5)
safe_zip/arg_lengths:10/num_args:1 1.61µs ± 1% 0.69µs ± 3% -56.93% (p=0.008 n=5+5)
safe_zip/arg_lengths:100/num_args:1 5.39µs ± 1% 3.83µs ± 2% -29.03% (p=0.008 n=5+5)
safe_zip/arg_lengths:0/num_args:2 1.46µs ± 1% 0.32µs ± 4% -78.30% (p=0.008 n=5+5)
safe_zip/arg_lengths:1/num_args:2 1.52µs ± 1% 0.39µs ± 4% -74.20% (p=0.008 n=5+5)
safe_zip/arg_lengths:2/num_args:2 1.53µs ± 1% 0.44µs ± 4% -71.38% (p=0.008 n=5+5)
safe_zip/arg_lengths:5/num_args:2 1.66µs ± 2% 0.60µs ± 3% -63.96% (p=0.008 n=5+5)
safe_zip/arg_lengths:10/num_args:2 1.90µs ± 1% 0.82µs ± 3% -56.66% (p=0.008 n=5+5)
safe_zip/arg_lengths:100/num_args:2 6.51µs ± 1% 4.80µs ± 0% -26.23% (p=0.016 n=5+4)
safe_zip/arg_lengths:0/num_args:3 1.62µs ± 1% 0.36µs ± 4% -77.95% (p=0.008 n=5+5)
safe_zip/arg_lengths:1/num_args:3 1.68µs ± 1% 0.44µs ± 3% -73.75% (p=0.008 n=5+5)
safe_zip/arg_lengths:2/num_args:3 1.69µs ± 1% 0.50µs ± 3% -70.48% (p=0.008 n=5+5)
safe_zip/arg_lengths:5/num_args:3 1.83µs ± 1% 0.68µs ± 2% -62.73% (p=0.008 n=5+5)
safe_zip/arg_lengths:10/num_args:3 2.12µs ± 1% 0.96µs ± 1% -54.71% (p=0.008 n=5+5)
safe_zip/arg_lengths:100/num_args:3 7.34µs ± 2% 5.89µs ± 1% -19.74% (p=0.008 n=5+5)
```
In addition, improve the length mismatch error for safe_map and define __module__ on both functions.
PiperOrigin-RevId: 523475834
2023-04-11 12:43:04 -07:00
Peter Hawkins
0dbd467cea
Add a C++ implementation of safe map.
...
Before (argument names reversed, oops, fixed in code):
```
name time/op
safe_map/num_args:0/arg_lengths:1 1.43µs ± 1%
safe_map/num_args:1/arg_lengths:1 1.61µs ± 1%
safe_map/num_args:2/arg_lengths:1 1.72µs ± 0%
safe_map/num_args:5/arg_lengths:1 2.14µs ± 1%
safe_map/num_args:10/arg_lengths:1 2.87µs ± 1%
safe_map/num_args:100/arg_lengths:1 15.6µs ± 1%
safe_map/num_args:0/arg_lengths:2 1.65µs ± 0%
safe_map/num_args:1/arg_lengths:2 1.83µs ± 1%
safe_map/num_args:2/arg_lengths:2 1.97µs ± 1%
safe_map/num_args:5/arg_lengths:2 2.41µs ± 1%
safe_map/num_args:10/arg_lengths:2 3.22µs ± 2%
safe_map/num_args:100/arg_lengths:2 17.0µs ± 2%
safe_map/num_args:0/arg_lengths:3 1.83µs ± 1%
safe_map/num_args:1/arg_lengths:3 2.02µs ± 1%
safe_map/num_args:2/arg_lengths:3 2.16µs ± 1%
safe_map/num_args:5/arg_lengths:3 2.63µs ± 1%
safe_map/num_args:10/arg_lengths:3 3.48µs ± 1%
safe_map/num_args:100/arg_lengths:3 18.1µs ± 1%
```
After:
```
name time/op
safe_map/num_args:0/arg_lengths:1 409ns ± 1%
safe_map/num_args:1/arg_lengths:1 602ns ± 5%
safe_map/num_args:2/arg_lengths:1 777ns ± 4%
safe_map/num_args:5/arg_lengths:1 1.21µs ± 3%
safe_map/num_args:10/arg_lengths:1 1.93µs ± 2%
safe_map/num_args:100/arg_lengths:1 14.7µs ± 0%
safe_map/num_args:0/arg_lengths:2 451ns ± 1%
safe_map/num_args:1/arg_lengths:2 652ns ± 0%
safe_map/num_args:2/arg_lengths:2 850ns ± 4%
safe_map/num_args:5/arg_lengths:2 1.32µs ± 3%
safe_map/num_args:10/arg_lengths:2 2.11µs ± 2%
safe_map/num_args:100/arg_lengths:2 16.0µs ± 1%
safe_map/num_args:0/arg_lengths:3 496ns ± 1%
safe_map/num_args:1/arg_lengths:3 718ns ± 5%
safe_map/num_args:2/arg_lengths:3 919ns ± 4%
safe_map/num_args:5/arg_lengths:3 1.43µs ± 2%
safe_map/num_args:10/arg_lengths:3 2.30µs ± 2%
safe_map/num_args:100/arg_lengths:3 17.3µs ± 1%
```
PiperOrigin-RevId: 523263207
2023-04-10 18:09:56 -07:00
Yash Katariya
694e43a44a
Remove experimental_cpp_jit
since that flag is unused and also remove experimental_cpp_pjit
.
...
For dynamic shapes experimentation and normal debugging, `python_pjit` still exists so that problem doesn't exist which makes us free to remove these 2 flags.
I am leaving pmap's flag alone for now.
PiperOrigin-RevId: 522602754
2023-04-07 08:29:20 -07:00
Peter Hawkins
452f3c55e3
Rename jax._src.sharding_utils to jax._src.op_shardings.
...
Move some more op_sharding related helpers to that module.
PiperOrigin-RevId: 522343010
2023-04-06 08:32:46 -07:00
Yash Katariya
cf8c2b8450
Delete benchmark and pmap_benchmark files as they are legacy and replaced with api_benchmark.py
...
PiperOrigin-RevId: 519742866
2023-03-27 09:22:57 -07:00
Yash Katariya
1faa7a8edd
Add benchmarks for accessing index and replica id in addressable_shards
...
PiperOrigin-RevId: 517974091
2023-03-20 08:22:34 -07:00
Parker Schuh
48702171bf
Add benchmarks for np.array, device_put, and _arrays.
...
PiperOrigin-RevId: 516692492
2023-03-14 19:06:06 -07:00
Yash Katariya
233911c001
[Fix forward] Rollback the device_put_sharded and device_put_replicated change of using batched_device_put
...
PiperOrigin-RevId: 516244071
2023-03-13 10:07:44 -07:00
Peter Hawkins
1925aa1109
Split Sharding subclasses out of _src/sharding.py into _src/sharding_impls.py
...
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.
PiperOrigin-RevId: 516223009
2023-03-13 08:50:18 -07:00
Emilio Cota
6f1d82916c
math_benchmark: add --set_env flag
...
PiperOrigin-RevId: 515417422
2023-03-09 13:04:12 -08:00
Emilio Cota
845d68b39e
math_benchmark: add dot op
...
PiperOrigin-RevId: 515408666
2023-03-09 12:24:47 -08:00
Peter Hawkins
8fb1fd318d
Replace jax._src.util.prod with math.prod.
...
math.prod() was added in Python 3.8, so we can assume it is always present.
PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Lena Martens
4f48f94649
Update api_benchmark to not use any deprecated APIs.
...
PiperOrigin-RevId: 512941633
2023-02-28 08:33:26 -08:00
Yash Katariya
418c2f9d2a
Rename in_axis_resources
and out_axis_resources
with in_shardings
and out_shardings
. This is just a simple name replacement. It does not change any of the current pjit semantics and doesn't break any code.
...
This is a safe and trivial name replacement. It does not change any of the semantics. You can still pass in PatitionSpecs to in_shardings and out_shardings.
PiperOrigin-RevId: 510671300
2023-02-18 10:00:36 -08:00
Yash Katariya
d21ff0371f
Remove gda_benchmark file as GDA is deprecated.
...
PiperOrigin-RevId: 510469600
2023-02-17 10:46:25 -08:00
Peter Hawkins
88cc254f2c
[JAX] Replace uses of jax.interpreters.pxla.ShardedDeviceArray with jax.Array.
...
PiperOrigin-RevId: 508463147
2023-02-09 13:39:41 -08:00
Peter Hawkins
98b75cf27b
Prune accidental exports from jax.interpreters.pxla.
...
These imports do not appear to have users outside JAX itself.
PiperOrigin-RevId: 507835295
2023-02-07 11:16:42 -08:00
Peter Hawkins
428189f8fb
Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
...
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.
PiperOrigin-RevId: 506994892
2023-02-03 14:28:45 -08:00
Jake VanderPlas
43e57db77a
Begin deprecation of public jax.ShapedArray
2023-01-30 11:27:58 -08:00
Emilio Cota
13e875f8b8
benchmarks: add math unary benchmarks
...
These will be used for benchmarking FP approximations in XLA.
PiperOrigin-RevId: 503991586
2023-01-23 08:17:16 -08:00
jax authors
eb875cd5dd
Added a pattern-match optimisation for inplace-select.
...
PiperOrigin-RevId: 497425937
2022-12-23 16:05:56 -08:00
Peter Hawkins
d6c67c97db
Remove redundant dtype canonicalization from jax.device_put().
...
Gives a small improvement to the included jax.device_put() benchmark on my VM:
```
name old cpu/op new cpu/op delta
device_put 91.3µs ± 5% 80.1µs ± 3% -12.29% (p=0.008 n=5+5)
name old time/op new time/op delta
device_put 91.4µs ± 5% 80.1µs ± 3% -12.29% (p=0.008 n=5+5)
```
jax.device_put() has not been optimized that much yet and there is plenty of room for further improvement.
PiperOrigin-RevId: 491727173
2022-11-29 13:47:36 -08:00
Yash Katariya
928dee415f
Optimize host_local_array_to_global_array
by caching the local to global conversion and flattening of axis resources. Also take a fast path for device_put which does not do abstractify
and only canonicalize_dtype on the entire array once (instead of doing it for every shard).
...
This results in a 5x speedup!
Before:
```
---------------------------------------------------------------------------
Benchmark Time CPU Iterations
---------------------------------------------------------------------------
host_local_array_to_global_array 3.03 ms 3.02 ms 220
```
After:
```
---------------------------------------------------------------------------
Benchmark Time CPU Iterations
---------------------------------------------------------------------------
host_local_array_to_global_array 0.673 ms 0.671 ms 985
```
PiperOrigin-RevId: 489880547
2022-11-20 20:53:02 -08:00
Yash Katariya
c42bad85ef
Make MeshPspecSharding
an alias for NamedSharding
(it was the other way around before this CL).
...
PiperOrigin-RevId: 488473538
2022-11-14 14:44:00 -08:00
jax authors
f4be5ab173
Merge pull request #12219 from jakevdp:indexing-slice
...
PiperOrigin-RevId: 485946084
2022-11-03 12:44:28 -07:00
Yash Katariya
532cd7ed74
Skip the benchmarks properly via state.skip_with_error when enough devices are not present.
...
PiperOrigin-RevId: 485931295
2022-11-03 11:44:57 -07:00
Jake VanderPlas
753562d574
Add benchmarks for repeated static indexing & slicing
2022-11-03 11:41:37 -07:00
Hyeontaek Lim
fc8f40ce0e
Internal visibility change
...
PiperOrigin-RevId: 484340424
2022-10-27 13:49:16 -07:00
Yash Katariya
cf6b5097d0
Remove pytest_benchmark for test-requirements.txt and move the benchmark file which was using that package to use google_benchmark.
...
PiperOrigin-RevId: 483736267
2022-10-25 11:59:32 -07:00
Yash Katariya
3572bb2db0
[Rollback]
...
Allow uncommitted single device PyArray in C++ pjit path.
PiperOrigin-RevId: 482084898
2022-10-18 19:42:10 -07:00
Kuangyuan Chen
d64da3d407
Roll forward with fix: Remove the original python function fun_
from C++ PjitFunction, as the destroying fun_
may yield the thread in some cases, which causes error during deleting the python object of PjitFunction.
...
PiperOrigin-RevId: 481950912
2022-10-18 10:05:53 -07:00
Kuangyuan Chen
fd2f590b3b
Allow uncommitted single device PyArray in C++ pjit path.
...
PiperOrigin-RevId: 481711690
2022-10-17 12:35:30 -07:00
jax authors
504b3c1b25
roll forward with the fix: Make params
arg in Compiled.call() position-only so that it does not conflict with the keyword args.
...
PiperOrigin-RevId: 481666211
2022-10-17 09:50:55 -07:00
Kuangyuan Chen
38a7582923
roll forward with the fix: Make params
arg in Compiled.call() position-only so that it does not conflict with the keyword args.
...
PiperOrigin-RevId: 481181330
2022-10-14 10:42:15 -07:00
jax authors
1945208d34
Rollback because of failing tests internally.
...
PiperOrigin-RevId: 481103002
2022-10-14 03:12:42 -07:00
Kuangyuan Chen
d082ea0d46
Implement a fast path for pjit AOT in C++ for jax.Array inputs.
...
PiperOrigin-RevId: 480983807
2022-10-13 14:24:05 -07:00
Yash Katariya
84768d2d49
Replace jax.xla.DeviceArray
private type with the new public type jax.Array
.
...
PiperOrigin-RevId: 477582562
2022-09-28 16:34:10 -07:00
Yash Katariya
9e4114f0f1
Move array.py
and sharding.py
from experimental/
to _src/
.
...
PiperOrigin-RevId: 477201711
2022-09-27 10:06:52 -07:00
Peter Hawkins
ba557d5e1b
Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
...
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.
PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Kuangyuan Chen
405a2310ce
Implement pjit fast path in cpp for jax.Array inputs
...
PiperOrigin-RevId: 475988677
2022-09-21 20:18:18 -07:00
Kuangyuan Chen
b9e384363e
Add many args benchmark for jax.Array
...
PiperOrigin-RevId: 475853211
2022-09-21 09:54:51 -07:00
Kuangyuan Chen
b764aadbcf
Add pmap benchmarks using jax.Array
...
PiperOrigin-RevId: 473325120
2022-09-09 13:10:42 -07:00
Kuangyuan Chen
d17e516ea7
Add benchmarks for jax.Array
...
PiperOrigin-RevId: 471889808
2022-09-02 14:45:09 -07:00
Luca Di Grazia
88f6051233
target_total_secs has type int
but used as type None
...
"filename": "benchmarks/benchmark.py"
"warning_type": "Incompatible variable type [9]",
"warning_message": " target_total_secs is declared to have type `int` but is used as type `None`."
"warning_line": 86
"fix": int to Optional[int]
2022-08-22 15:46:38 +02:00
Matthew Johnson
b3b4ffbc21
make slicing compilation a little faster
...
For the common special case of int or slice-of-int/None indexing, we can
generate a lax.slice rather than a lax.gather. That makes compilation a little
faster, and makes the generated jaxpr a bit more wieldy too, to process in
transformations and to read when pretty-printed.
2022-08-19 09:14:37 -07:00
Matthew Johnson
42dd7cac43
simplify slicing jaxprs a little
...
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-08-12 19:52:02 -07:00
Matthew Johnson
e3a92d52ba
prepare to switch to new remat
...
This commit involves a few things, which are all united in being about landing
the new remat (aka new checkpoint) implementation:
* add benchmarks for new remat eager performance, and some caching to make those
benchmarks fast
* warn when the old-remat-exclusive `concrete` feature is used, with an
actionable message pointing to the new recommended approach involving static_argnums
* add the static_argnums parameter to both new and old remt
* update docstrings (and de-duplicate them to)
* add new tests, especially around caching and errors/warnings
2022-08-04 12:25:03 -07:00