Sergei Lebedev
0ff234049b
Removed trivial docstrings from JAX tests
...
These docstrings do not make the tests any more clear and typically just duplicate the test module name.
PiperOrigin-RevId: 737611977
2025-03-17 07:49:37 -07:00
Yash Katariya
6d8be966a0
Fix shard_map debug_nan leakage of manual out_avals in the impl rules of jit i.e. impl rule of jit saw a manual out_aval which is not expected. This is a band-aid for now with a TODO to do a proper fix
...
PiperOrigin-RevId: 730499532
2025-02-24 10:15:21 -08:00
Michael Hudgins
2e808f2836
Merge pull request #26279 from MichaelHudgins:tsan-resultstore
...
PiperOrigin-RevId: 723918760
2025-02-06 14:55:57 +00:00
Dougal Maclaurin
32bf19ac6f
Add a temporary fix for spurious debug_nans errors when round-tripping jaxprs.
...
debug_nans is sometimes disabled locally at the traceable level by ops that work with nans internally, like jnp.var. But we don't capture this local change-of-context in the jaxpr. The right thing to do is to add contexts to our jaxpr representation so that we can capture these local context modifications. In the meantime, disabling the checks when we round-trip prevents those ops producing spurious errors.
PiperOrigin-RevId: 691494516
2024-10-30 11:34:08 -07:00
Michael Hudgins
d4d1518c3d
Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
...
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Yash Katariya
0d5dae09ff
Delete xmap
and the jax.experimental.maps
module. It's been 5 months since its deprecation (more than the standard 3 months deprecation period).
...
PiperOrigin-RevId: 655614395
2024-07-24 10:24:09 -07:00
Jake VanderPlas
e6e4acb7c3
tests: set configs with jtu.with_config rather than manually
2024-06-05 13:34:32 -07:00
Jake VanderPlas
f090074d86
Avoid 'from jax import config' imports
...
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Jake VanderPlas
84e49bd6ce
Remove internal references to deprecated jax.experimental.maps
2024-03-19 09:24:52 -07:00
Yash Katariya
e624610e72
Replace apply_primitive internals with jax.jit
.
...
This allows deletion of a lot of code and leads to ~40% eager performance speedup.
Benchmarks:
```
name old time/op new time/op delta
eager_unary_dispatch 31.3µs ± 1% 19.4µs ± 6% -37.91% (p=0.016 n=4+5)
eager_unary 32.1µs ± 0% 19.8µs ± 4% -38.26% (p=0.016 n=4+5)
eager_binary_dispatch 35.9µs ± 1% 20.5µs ± 4% -42.93% (p=0.016 n=4+5)
eager_binary 36.6µs ± 1% 21.1µs ± 4% -42.29% (p=0.016 n=4+5)
jit_trivial_dispatch 3.87µs ± 2% 4.12µs ±25% ~ (p=1.000 n=5+5)
jit_trivial 4.75µs ± 2% 4.82µs ±11% ~ (p=0.690 n=5+5)
jit_simple_dispatch 2.95µs ± 2% 2.97µs ± 7% ~ (p=1.000 n=5+5)
jit_simple 3.52µs ± 6% 3.51µs ± 5% ~ (p=0.841 n=5+5)
jit_simple_dispatch_array 2.95µs ± 2% 2.96µs ± 6% ~ (p=1.000 n=5+5)
jit_simple_array 3.46µs ± 2% 3.51µs ± 5% ~ (p=0.690 n=5+5)
jit_small_matmul 3.01µs ± 1% 3.00µs ± 4% ~ (p=0.548 n=5+5)
jit_big_matmul 34.0µs ±18% 35.5µs ±17% ~ (p=0.310 n=5+5)
jit_simple_many_args_dispatch/num_args:10 6.93µs ± 6% 6.80µs ± 6% ~ (p=0.481 n=10+10)
jit_simple_many_args_dispatch/num_args:100 47.7µs ± 7% 45.4µs ± 2% ~ (p=0.237 n=10+8)
jit_simple_many_args_dispatch/num_args:1000 545µs ± 8% 516µs ± 2% ~ (p=0.101 n=10+8)
jit_simple_many_args_dispatch/num_args:2000 1.12ms ± 7% 1.07ms ± 2% ~ (p=0.237 n=10+8)
jit_simple_many_args/num_args:10 7.42µs ± 5% 7.23µs ± 2% ~ (p=0.173 n=10+8)
jit_simple_many_args/num_args:100 48.4µs ± 7% 45.6µs ± 2% ~ (p=0.237 n=10+8)
jit_simple_many_args/num_args:1000 542µs ± 6% 524µs ± 8% ~ (p=0.089 n=10+10)
jit_simple_many_args/num_args:2000 1.12ms ± 7% 1.08ms ± 1% ~ (p=0.068 n=10+8)
jit_simple_pruned_args_dispatch_10 4.79µs ± 8% 4.98µs ±10% ~ (p=0.421 n=5+5)
jit_simple_pruned_args_10 5.32µs ± 6% 5.30µs ± 4% ~ (p=1.000 n=5+5)
jit_simple_pruned_args_dispatch_100 24.7µs ± 6% 23.8µs ± 8% ~ (p=0.548 n=5+5)
jit_simple_pruned_args_100 25.2µs ± 6% 24.4µs ± 8% ~ (p=0.690 n=5+5)
jit_simple_pruned_args_dispatch_1000 238µs ± 7% 232µs ± 8% ~ (p=0.841 n=5+5)
jit_simple_pruned_args_1000 240µs ± 7% 234µs ± 8% ~ (p=1.000 n=5+5)
jit_simple_pruned_args_dispatch_2000 516µs ± 6% 497µs ± 1% ~ (p=0.413 n=5+4)
jit_simple_pruned_args_2000 517µs ± 6% 505µs ± 7% ~ (p=0.690 n=5+5)
jit_dispatch_without_transfer 719µs ± 9% 751µs ± 8% ~ (p=0.222 n=5+5)
jit_dispatch_with_transfer 799µs ±14% 793µs ± 9% ~ (p=1.000 n=5+5)
pmap_trivial_2_devices 49.9µs ±40% 48.2µs ±42% ~ (p=0.841 n=5+5)
pmap_trivial_dispatch_8_devices 74.5µs ±24% 78.9µs ±29% ~ (p=0.421 n=5+5)
pmap_trivial_8_devices 79.3µs ± 6% 82.7µs ±20% ~ (p=0.841 n=5+5)
pmap_simple_2_devices 47.1µs ±17% 49.1µs ±20% ~ (p=0.548 n=5+5)
pmap_simple_dispatch_8_devices 73.4µs ±16% 76.8µs ±21% ~ (p=0.690 n=5+5)
pmap_simple_8_devices 76.0µs ±10% 80.6µs ±29% ~ (p=1.000 n=5+5)
pmap_simple_dispatch_8_devices_100_args 1.12ms ±22% 1.08ms ±42% ~ (p=0.841 n=5+5)
pmap_simple_8_devices_100_args 12.5ms ± 8% 12.8ms ±10% ~ (p=1.000 n=5+5)
sda_index_1 413µs ± 1% 686µs ± 4% +66.08% (p=0.008 n=5+5)
sda_index_2 850µs ± 1% 1378µs ± 4% +62.02% (p=0.008 n=5+5)
sda_index_8 3.60ms ± 1% 5.69ms ± 4% +58.00% (p=0.008 n=5+5)
bench_shaped_abstractify 300µs ± 1% 305µs ± 3% ~ (p=0.056 n=5+5)
bench_xla_abstractify_scalar_int 6.45µs ± 1% 6.50µs ± 3% ~ (p=0.548 n=5+5)
bench_xla_abstractify_scalar_float 3.73µs ± 1% 3.73µs ± 3% ~ (p=0.690 n=5+5)
bench_xla_abstractify_scalar_numpy_int32 4.97µs ± 1% 4.83µs ± 3% ~ (p=0.095 n=5+5)
bench_xla_abstractify_scalar_numpy_uint32 4.91µs ± 1% 4.75µs ± 0% -3.30% (p=0.016 n=5+4)
bench_xla_abstractify_numpy_random 4.34µs ± 2% 4.31µs ± 3% ~ (p=0.310 n=5+5)
bench_xla_abstractify_numpy_arange_100_float32 3.94µs ± 1% 3.93µs ± 3% ~ (p=0.548 n=5+5)
bench_xla_abstractify_enum 6.85µs ± 1% 7.06µs ± 7% +3.07% (p=0.032 n=5+5)
bench_are_op_shardings_equal 26.9µs ± 2% 27.0µs ± 3% ~ (p=0.841 n=5+5)
bench_pjit_check_aval_sharding 691µs ± 2% 711µs ±13% ~ (p=0.841 n=5+5)
bench_addressable_shards_index 656ns ± 4% 688ns ± 9% ~ (p=0.095 n=5+5)
bench_remat_eager_retracing_overheads 12.7ms ± 4% 10.7ms ± 1% -15.48% (p=0.016 n=5+4)
bench_remat_eager_retracing_overheads_static_argnums 13.0ms ± 2% 11.3ms ± 6% -13.71% (p=0.008 n=5+5)
bench_slicing_compilation 12.1ms ± 1% 12.3ms ± 4% ~ (p=0.690 n=5+5)
bench_slicing_compilation2 11.3ms ± 0% 11.5ms ± 6% ~ (p=0.690 n=5+5)
bench_repeated_static_indexing 62.5ms ± 2% 40.8ms ± 8% -34.77% (p=0.008 n=5+5)
bench_repeated_static_slicing 46.7ms ± 1% 31.4ms ± 2% -32.76% (p=0.008 n=5+5)
pjit_simple_1_device/num_args:1 2.72µs ± 2% 2.68µs ± 5% ~ (p=0.151 n=5+5)
pjit_simple_1_device/num_args:10 12.6µs ± 7% 12.3µs ± 3% ~ (p=0.310 n=5+5)
pjit_simple_1_device/num_args:100 109µs ± 3% 108µs ± 4% ~ (p=0.548 n=5+5)
pjit_simple_4_device/num_args:1 38.0µs ±26% 36.8µs ±19% ~ (p=0.690 n=5+5)
pjit_simple_4_device/num_args:10 93.3µs ±19% 96.6µs ±23% ~ (p=0.841 n=5+5)
pjit_simple_4_device/num_args:100 730µs ±16% 698µs ±48% ~ (p=0.841 n=5+5)
pjit_aot_1_device/num_args:1 3.29µs ± 2% 3.12µs ± 4% -5.24% (p=0.016 n=4+5)
pjit_aot_1_device/num_args:10 13.0µs ± 1% 12.7µs ± 2% ~ (p=0.063 n=4+5)
pjit_aot_1_device/num_args:100 111µs ± 5% 110µs ±11% ~ (p=0.421 n=5+5)
pjit_aot_4_device/num_args:1 38.4µs ±19% 38.9µs ±24% ~ (p=1.000 n=5+5)
pjit_aot_4_device/num_args:10 91.3µs ±15% 96.9µs ±29% ~ (p=0.548 n=5+5)
pjit_aot_4_device/num_args:100 676µs ±20% 689µs ±41% ~ (p=0.841 n=5+5)
host_local_array_to_global_array 196µs ± 6% 194µs ± 4% ~ (p=0.548 n=5+5)
device_put 50.8µs ± 1% 50.7µs ± 4% ~ (p=0.413 n=4+5)
device_put_sharded 176µs ± 0% 177µs ± 4% ~ (p=0.190 n=4+5)
device_get_8_devices 3.96ms ± 4% 4.03ms ± 7% ~ (p=0.413 n=4+5)
np_asarray_8_devices 3.34ms ±18% 3.30ms ±10% ~ (p=0.548 n=5+5)
jax_array_arrays_8_devices 5.01ms ±10% 5.09ms ±21% ~ (p=0.421 n=5+5)
batch_inplace_while_scatter 440µs ± 1% 439µs ± 1% ~ (p=0.421 n=5+5)
batch_inplace_while_dynamic_update_slice 454µs ± 0% 457µs ± 1% ~ (p=0.905 n=4+5)
serial_dot_products 4.51µs ± 3% 4.41µs ± 2% ~ (p=0.151 n=5+5)
bench_make_array_from_callback_fully_replicated_sharding 26.6µs ± 1% 27.0µs ± 2% ~ (p=0.056 n=5+5)
```
PiperOrigin-RevId: 586505950
2023-11-29 18:07:13 -08:00
Jake VanderPlas
fbe4f10403
Change to simpler import for jax.config
2023-04-21 11:51:22 -07:00
Yash Katariya
738dd719bd
Remove experimental_cpp_pmap flag since it is always on
...
PiperOrigin-RevId: 522631405
2023-04-07 10:42:11 -07:00
Peter Hawkins
dea7450e4e
Remove references to jax.config.jax_array, which is always True at head.
...
PiperOrigin-RevId: 516970232
2023-03-15 17:09:11 -07:00
Yash Katariya
52a7701dda
Replace usage of {in|out}_axis_resources with {in|out}_shardings
...
PiperOrigin-RevId: 513040164
2023-02-28 14:29:09 -08:00
Yash Katariya
418c2f9d2a
Rename in_axis_resources
and out_axis_resources
with in_shardings
and out_shardings
. This is just a simple name replacement. It does not change any of the current pjit semantics and doesn't break any code.
...
This is a safe and trivial name replacement. It does not change any of the semantics. You can still pass in PatitionSpecs to in_shardings and out_shardings.
PiperOrigin-RevId: 510671300
2023-02-18 10:00:36 -08:00
Yash Katariya
1c651f2ea4
Catch the NaN's and raise a better error message when jax_debug_nans flag is True.
...
PiperOrigin-RevId: 509552717
2023-02-14 09:27:36 -08:00
Peter Hawkins
428189f8fb
Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
...
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.
PiperOrigin-RevId: 506994892
2023-02-03 14:28:45 -08:00
Yash Katariya
2f3d75aa03
Remove dependency of maps from pjit to avoid circular imports when importing pjit in api.py.
...
PiperOrigin-RevId: 497230514
2022-12-22 13:35:23 -08:00
Peter Hawkins
2c6c30d458
Bump the minimum jaxlib version to 0.4.1.
...
Jaxlib 0.4.1 has XLA client version 109 and MLIR API version 39.
2022-12-19 17:49:24 +00:00
lenamartens
e80c34d624
Don't donate arguments in jit/pmap/pjit when debug_nans=True.
2022-11-08 13:33:59 +00:00
Peter Hawkins
c657449528
Copybara import of the project:
...
--
d39bdefb33a19e407c352df27fb04127f4fe8a1d by Peter Hawkins <phawkins@google.com>:
Migrate more tests from jtu.cases_from_list to jtu.sample_product.
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/12717 from hawkinsp:sampletest d39bdefb33a19e407c352df27fb04127f4fe8a1d
PiperOrigin-RevId: 480136538
2022-10-10 11:35:32 -07:00
Yash Katariya
fb8558cfdd
Add jax_array coverage to debug_nans_test
...
PiperOrigin-RevId: 478079509
2022-09-30 14:21:32 -07:00
Yash Katariya
9ff570e6c3
Make debug_nans_test.py pass with jax_array=1. Both with enabled and disabled jax_array flag and --pdb_post_mortem, we fall to the same place.
...
PiperOrigin-RevId: 477850567
2022-09-29 16:29:58 -07:00
Yash Katariya
b4e1d0af8a
Propagate name
through ExecuteReplicated for dispatch.check_special
...
PiperOrigin-RevId: 477351323
2022-09-27 21:32:32 -07:00
Peter Hawkins
ba557d5e1b
Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
...
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.
PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Yash Katariya
7fbf8ec669
Fix Forward. The fix is on the user's end. Original PR: https://github.com/google/jax/pull/12217
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
Co-authored-by: Yash Katariya <yashkatariya@google.com>
PiperOrigin-RevId: 472999907
2022-09-08 08:49:40 -07:00
jax authors
14f1a345a1
roll back breakage
...
PiperOrigin-RevId: 472949225
2022-09-08 03:59:54 -07:00
Yash Katariya
b7e4e44cbf
DCE jaxpr and trivial_jaxpr support for lower_sharding_computation
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 471274989
2022-09-06 14:09:10 -07:00
Jeppe Klitgaard
838a05329d
feat: validate jit args
2022-05-18 21:54:47 +01:00
Peter Hawkins
634f58c7d5
Enable a number of tests on GPU.
...
In particular, pjit/xmap work on CPU these days.
PiperOrigin-RevId: 446085110
2022-05-02 18:57:27 -07:00
Matthew Johnson
8bc8e40e72
debug_nans: don't return results of successfully running de-optimized function
2022-04-12 14:40:19 -07:00
Yash Katariya
687a7630ee
Deprecate maps.mesh
and replace it with maps.Mesh
.
...
PiperOrigin-RevId: 430489855
2022-02-23 10:47:06 -08:00
Peter Hawkins
3fd3c46f20
Increase minimum jaxlib version to 0.1.74.
2021-11-18 15:06:58 -05:00
Peter Hawkins
db2e91eba2
Move jax.test_util to jax._src.test_util.
...
Add forwarding shims for names used by external clients of JAX in practice.
PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Peter Hawkins
2c2f4033cc
Move contents of jax.lib to jax._src.lib.
...
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.
PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
Jean-Baptiste Lespiau
f6f1debf70
Add post_hook support for pmap, to support debug_nans and debug_infs.
...
It's the exact same code as for JIT. We just modify the Python function to accept ShardedDeviceArray in addition to DeviceArray objects. The test is updated accordingly.
PiperOrigin-RevId: 391272270
2021-08-17 06:11:47 -07:00
Peter Hawkins
1aec989aa3
Fix "Store empty" error due to debug_nans corrupting cache entries.
...
Rather than mutating the existing WrappedFun, clone it with fresh stores. The stores aren't connected to anything, but that's fine: we can treat the deoptimized computation as a throwaway computation; the "real" computation is the jit-compiled version and we are ultimately going to use its stores if we don't throw an exception.
2021-08-11 10:59:42 -04:00
Jake VanderPlas
80d8f2d56c
jnp.sinc: fix NaNs at x=0
2021-06-10 09:14:07 -07:00
Peter Hawkins
26e9ebcdae
Move jax.api to jax._src.api.
...
PiperOrigin-RevId: 368233837
2021-04-13 09:43:24 -07:00
Matthew Johnson
2b79264354
remove disable_omnistaging mechanism
2021-03-29 15:26:57 -07:00
Peter Hawkins
cac1b891ce
[JAX] Refactor NaN/Inf checking in jitted functions.
...
Avoid performing NaN/Inf checking in the common path for calling a jit-ted function. Instead, add a global/thread-local `posthook` function that, if, set, the C++ jit code calls with the inputs (function, args, kwargs, outputs). Use the posthook feature to implement NaN checking.
Add a `_cache_miss` attribute to the C++ JIT function objects to allow the NaN checking code to extract and call the cache miss function.
PiperOrigin-RevId: 365108787
2021-03-25 13:13:02 -07:00
Matthew Johnson
fd7b286ec9
unify configuration state handling
2021-03-23 18:56:01 -07:00
Skye Wanderman-Milne
c56649aaac
Make jax_debug_nans and jax_debug_infs work with pmap, xmap, and pjit.
...
Note that unlike in the jit case, this doesn't rerun the function in
op-by-op mode when it finds a nan, since we don't have op-by-op
parallel execution yet :)
This change doesn't appear to regress performance:
```
---------Benchmark summary for pmap_shard_outputs---------
nouts nshards mean %std relative mean/baseline
------- --------- --------- -------- ---------- ---------------
10 8 0.105598 5.06671 1 1.00693
100 8 0.287756 0.870751 2.72502 0.973204
500 8 1.20119 0.823624 11.3752 0.955185
1000 8 2.56071 0 24.2497 0.983063
5000 8 12.909 0 122.247 0.965925
100 2 0.173727 5.15115 1.64518 0.98918
100 4 0.207774 3.71411 1.9676 0.955849
100 8 0.286103 1.60243 2.70937 0.971869
100 100 2.34168 0 22.1755 0.904475
100 500 15.9558 0 151.1 1.00483
```
Fixes #6044
2021-03-12 16:22:55 -08:00
Jake VanderPlas
5e7be4a61f
Cleanup: remove obsolete jaxlib version checks
2021-02-04 15:13:39 -08:00
George Necula
0e932aeb72
Update debug_nans_test.py
...
Fix typo
2021-01-15 14:40:37 +02:00
Mike Innes
0e73bb9281
inf checker tests
2021-01-06 14:43:05 +00:00
Jean-Baptiste Lespiau
5a097f5ca9
Gate some jax_jit test with a version check.
2020-10-12 20:04:19 +02:00
Jean-Baptiste Lespiau
c1e25953a3
Add support for jax_debug_nans and fix the last few glitches with the C++ jax.jit.
...
- Sorting the keyword arguments must be done on the string, because we go through the Python path which uses flatten() which sort them by string.
- Some error with obj == obj which is the same as obj.is(obj) and not obj.equal(obj).
- Moves all the Python tests to the C++ tests (which also run on the _python_jit).
PiperOrigin-RevId: 336671123
2020-10-12 08:50:13 -07:00
Matthew Johnson
24de811a39
move a debug_nans test into debug_nans test file
2020-10-08 13:34:56 -07:00
Matthew Johnson
09f2be15d2
wait for result in debug_nans_test
2020-10-08 13:00:32 -07:00