61 Commits

Author SHA1 Message Date
Sergei Lebedev
0ff234049b Removed trivial docstrings from JAX tests
These docstrings do not make the tests any more clear and typically just duplicate the test module name.

PiperOrigin-RevId: 737611977
2025-03-17 07:49:37 -07:00
Yash Katariya
6d8be966a0 Fix shard_map debug_nan leakage of manual out_avals in the impl rules of jit i.e. impl rule of jit saw a manual out_aval which is not expected. This is a band-aid for now with a TODO to do a proper fix
PiperOrigin-RevId: 730499532
2025-02-24 10:15:21 -08:00
Michael Hudgins
2e808f2836 Merge pull request #26279 from MichaelHudgins:tsan-resultstore
PiperOrigin-RevId: 723918760
2025-02-06 14:55:57 +00:00
Dougal Maclaurin
32bf19ac6f Add a temporary fix for spurious debug_nans errors when round-tripping jaxprs.
debug_nans is sometimes disabled locally at the traceable level by ops that work with nans internally, like jnp.var. But we don't capture this local change-of-context in the jaxpr. The right thing to do is to add contexts to our jaxpr representation so that we can capture these local context modifications. In the meantime, disabling the checks when we round-trip prevents those ops producing spurious errors.

PiperOrigin-RevId: 691494516
2024-10-30 11:34:08 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Yash Katariya
0d5dae09ff Delete xmap and the jax.experimental.maps module. It's been 5 months since its deprecation (more than the standard 3 months deprecation period).
PiperOrigin-RevId: 655614395
2024-07-24 10:24:09 -07:00
Jake VanderPlas
e6e4acb7c3 tests: set configs with jtu.with_config rather than manually 2024-06-05 13:34:32 -07:00
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Jake VanderPlas
84e49bd6ce Remove internal references to deprecated jax.experimental.maps 2024-03-19 09:24:52 -07:00
Yash Katariya
e624610e72 Replace apply_primitive internals with jax.jit.
This allows deletion of a lot of code and leads to ~40% eager performance speedup.

Benchmarks:

```
name                                                      old time/op          new time/op          delta
eager_unary_dispatch                                      31.3µs ± 1%          19.4µs ± 6%  -37.91%    (p=0.016 n=4+5)
eager_unary                                               32.1µs ± 0%          19.8µs ± 4%  -38.26%    (p=0.016 n=4+5)
eager_binary_dispatch                                     35.9µs ± 1%          20.5µs ± 4%  -42.93%    (p=0.016 n=4+5)
eager_binary                                              36.6µs ± 1%          21.1µs ± 4%  -42.29%    (p=0.016 n=4+5)
jit_trivial_dispatch                                      3.87µs ± 2%          4.12µs ±25%     ~       (p=1.000 n=5+5)
jit_trivial                                               4.75µs ± 2%          4.82µs ±11%     ~       (p=0.690 n=5+5)
jit_simple_dispatch                                       2.95µs ± 2%          2.97µs ± 7%     ~       (p=1.000 n=5+5)
jit_simple                                                3.52µs ± 6%          3.51µs ± 5%     ~       (p=0.841 n=5+5)
jit_simple_dispatch_array                                 2.95µs ± 2%          2.96µs ± 6%     ~       (p=1.000 n=5+5)
jit_simple_array                                          3.46µs ± 2%          3.51µs ± 5%     ~       (p=0.690 n=5+5)
jit_small_matmul                                          3.01µs ± 1%          3.00µs ± 4%     ~       (p=0.548 n=5+5)
jit_big_matmul                                            34.0µs ±18%          35.5µs ±17%     ~       (p=0.310 n=5+5)
jit_simple_many_args_dispatch/num_args:10                 6.93µs ± 6%          6.80µs ± 6%     ~     (p=0.481 n=10+10)
jit_simple_many_args_dispatch/num_args:100                47.7µs ± 7%          45.4µs ± 2%     ~      (p=0.237 n=10+8)
jit_simple_many_args_dispatch/num_args:1000                545µs ± 8%           516µs ± 2%     ~      (p=0.101 n=10+8)
jit_simple_many_args_dispatch/num_args:2000               1.12ms ± 7%          1.07ms ± 2%     ~      (p=0.237 n=10+8)
jit_simple_many_args/num_args:10                          7.42µs ± 5%          7.23µs ± 2%     ~      (p=0.173 n=10+8)
jit_simple_many_args/num_args:100                         48.4µs ± 7%          45.6µs ± 2%     ~      (p=0.237 n=10+8)
jit_simple_many_args/num_args:1000                         542µs ± 6%           524µs ± 8%     ~     (p=0.089 n=10+10)
jit_simple_many_args/num_args:2000                        1.12ms ± 7%          1.08ms ± 1%     ~      (p=0.068 n=10+8)
jit_simple_pruned_args_dispatch_10                        4.79µs ± 8%          4.98µs ±10%     ~       (p=0.421 n=5+5)
jit_simple_pruned_args_10                                 5.32µs ± 6%          5.30µs ± 4%     ~       (p=1.000 n=5+5)
jit_simple_pruned_args_dispatch_100                       24.7µs ± 6%          23.8µs ± 8%     ~       (p=0.548 n=5+5)
jit_simple_pruned_args_100                                25.2µs ± 6%          24.4µs ± 8%     ~       (p=0.690 n=5+5)
jit_simple_pruned_args_dispatch_1000                       238µs ± 7%           232µs ± 8%     ~       (p=0.841 n=5+5)
jit_simple_pruned_args_1000                                240µs ± 7%           234µs ± 8%     ~       (p=1.000 n=5+5)
jit_simple_pruned_args_dispatch_2000                       516µs ± 6%           497µs ± 1%     ~       (p=0.413 n=5+4)
jit_simple_pruned_args_2000                                517µs ± 6%           505µs ± 7%     ~       (p=0.690 n=5+5)
jit_dispatch_without_transfer                              719µs ± 9%           751µs ± 8%     ~       (p=0.222 n=5+5)
jit_dispatch_with_transfer                                 799µs ±14%           793µs ± 9%     ~       (p=1.000 n=5+5)
pmap_trivial_2_devices                                    49.9µs ±40%          48.2µs ±42%     ~       (p=0.841 n=5+5)
pmap_trivial_dispatch_8_devices                           74.5µs ±24%          78.9µs ±29%     ~       (p=0.421 n=5+5)
pmap_trivial_8_devices                                    79.3µs ± 6%          82.7µs ±20%     ~       (p=0.841 n=5+5)
pmap_simple_2_devices                                     47.1µs ±17%          49.1µs ±20%     ~       (p=0.548 n=5+5)
pmap_simple_dispatch_8_devices                            73.4µs ±16%          76.8µs ±21%     ~       (p=0.690 n=5+5)
pmap_simple_8_devices                                     76.0µs ±10%          80.6µs ±29%     ~       (p=1.000 n=5+5)
pmap_simple_dispatch_8_devices_100_args                   1.12ms ±22%          1.08ms ±42%     ~       (p=0.841 n=5+5)
pmap_simple_8_devices_100_args                            12.5ms ± 8%          12.8ms ±10%     ~       (p=1.000 n=5+5)
sda_index_1                                                413µs ± 1%           686µs ± 4%  +66.08%    (p=0.008 n=5+5)
sda_index_2                                                850µs ± 1%          1378µs ± 4%  +62.02%    (p=0.008 n=5+5)
sda_index_8                                               3.60ms ± 1%          5.69ms ± 4%  +58.00%    (p=0.008 n=5+5)
bench_shaped_abstractify                                   300µs ± 1%           305µs ± 3%     ~       (p=0.056 n=5+5)
bench_xla_abstractify_scalar_int                          6.45µs ± 1%          6.50µs ± 3%     ~       (p=0.548 n=5+5)
bench_xla_abstractify_scalar_float                        3.73µs ± 1%          3.73µs ± 3%     ~       (p=0.690 n=5+5)
bench_xla_abstractify_scalar_numpy_int32                  4.97µs ± 1%          4.83µs ± 3%     ~       (p=0.095 n=5+5)
bench_xla_abstractify_scalar_numpy_uint32                 4.91µs ± 1%          4.75µs ± 0%   -3.30%    (p=0.016 n=5+4)
bench_xla_abstractify_numpy_random                        4.34µs ± 2%          4.31µs ± 3%     ~       (p=0.310 n=5+5)
bench_xla_abstractify_numpy_arange_100_float32            3.94µs ± 1%          3.93µs ± 3%     ~       (p=0.548 n=5+5)
bench_xla_abstractify_enum                                6.85µs ± 1%          7.06µs ± 7%   +3.07%    (p=0.032 n=5+5)
bench_are_op_shardings_equal                              26.9µs ± 2%          27.0µs ± 3%     ~       (p=0.841 n=5+5)
bench_pjit_check_aval_sharding                             691µs ± 2%           711µs ±13%     ~       (p=0.841 n=5+5)
bench_addressable_shards_index                             656ns ± 4%           688ns ± 9%     ~       (p=0.095 n=5+5)
bench_remat_eager_retracing_overheads                     12.7ms ± 4%          10.7ms ± 1%  -15.48%    (p=0.016 n=5+4)
bench_remat_eager_retracing_overheads_static_argnums      13.0ms ± 2%          11.3ms ± 6%  -13.71%    (p=0.008 n=5+5)
bench_slicing_compilation                                 12.1ms ± 1%          12.3ms ± 4%     ~       (p=0.690 n=5+5)
bench_slicing_compilation2                                11.3ms ± 0%          11.5ms ± 6%     ~       (p=0.690 n=5+5)
bench_repeated_static_indexing                            62.5ms ± 2%          40.8ms ± 8%  -34.77%    (p=0.008 n=5+5)
bench_repeated_static_slicing                             46.7ms ± 1%          31.4ms ± 2%  -32.76%    (p=0.008 n=5+5)
pjit_simple_1_device/num_args:1                           2.72µs ± 2%          2.68µs ± 5%     ~       (p=0.151 n=5+5)
pjit_simple_1_device/num_args:10                          12.6µs ± 7%          12.3µs ± 3%     ~       (p=0.310 n=5+5)
pjit_simple_1_device/num_args:100                          109µs ± 3%           108µs ± 4%     ~       (p=0.548 n=5+5)
pjit_simple_4_device/num_args:1                           38.0µs ±26%          36.8µs ±19%     ~       (p=0.690 n=5+5)
pjit_simple_4_device/num_args:10                          93.3µs ±19%          96.6µs ±23%     ~       (p=0.841 n=5+5)
pjit_simple_4_device/num_args:100                          730µs ±16%           698µs ±48%     ~       (p=0.841 n=5+5)
pjit_aot_1_device/num_args:1                              3.29µs ± 2%          3.12µs ± 4%   -5.24%    (p=0.016 n=4+5)
pjit_aot_1_device/num_args:10                             13.0µs ± 1%          12.7µs ± 2%     ~       (p=0.063 n=4+5)
pjit_aot_1_device/num_args:100                             111µs ± 5%           110µs ±11%     ~       (p=0.421 n=5+5)
pjit_aot_4_device/num_args:1                              38.4µs ±19%          38.9µs ±24%     ~       (p=1.000 n=5+5)
pjit_aot_4_device/num_args:10                             91.3µs ±15%          96.9µs ±29%     ~       (p=0.548 n=5+5)
pjit_aot_4_device/num_args:100                             676µs ±20%           689µs ±41%     ~       (p=0.841 n=5+5)
host_local_array_to_global_array                           196µs ± 6%           194µs ± 4%     ~       (p=0.548 n=5+5)
device_put                                                50.8µs ± 1%          50.7µs ± 4%     ~       (p=0.413 n=4+5)
device_put_sharded                                         176µs ± 0%           177µs ± 4%     ~       (p=0.190 n=4+5)
device_get_8_devices                                      3.96ms ± 4%          4.03ms ± 7%     ~       (p=0.413 n=4+5)
np_asarray_8_devices                                      3.34ms ±18%          3.30ms ±10%     ~       (p=0.548 n=5+5)
jax_array_arrays_8_devices                                5.01ms ±10%          5.09ms ±21%     ~       (p=0.421 n=5+5)
batch_inplace_while_scatter                                440µs ± 1%           439µs ± 1%     ~       (p=0.421 n=5+5)
batch_inplace_while_dynamic_update_slice                   454µs ± 0%           457µs ± 1%     ~       (p=0.905 n=4+5)
serial_dot_products                                       4.51µs ± 3%          4.41µs ± 2%     ~       (p=0.151 n=5+5)
bench_make_array_from_callback_fully_replicated_sharding  26.6µs ± 1%          27.0µs ± 2%     ~       (p=0.056 n=5+5)
```

PiperOrigin-RevId: 586505950
2023-11-29 18:07:13 -08:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Yash Katariya
738dd719bd Remove experimental_cpp_pmap flag since it is always on
PiperOrigin-RevId: 522631405
2023-04-07 10:42:11 -07:00
Peter Hawkins
dea7450e4e Remove references to jax.config.jax_array, which is always True at head.
PiperOrigin-RevId: 516970232
2023-03-15 17:09:11 -07:00
Yash Katariya
52a7701dda Replace usage of {in|out}_axis_resources with {in|out}_shardings
PiperOrigin-RevId: 513040164
2023-02-28 14:29:09 -08:00
Yash Katariya
418c2f9d2a Rename in_axis_resources and out_axis_resources with in_shardings and out_shardings. This is just a simple name replacement. It does not change any of the current pjit semantics and doesn't break any code.
This is a safe and trivial name replacement. It does not change any of the semantics. You can still pass in PatitionSpecs to in_shardings and out_shardings.

PiperOrigin-RevId: 510671300
2023-02-18 10:00:36 -08:00
Yash Katariya
1c651f2ea4 Catch the NaN's and raise a better error message when jax_debug_nans flag is True.
PiperOrigin-RevId: 509552717
2023-02-14 09:27:36 -08:00
Peter Hawkins
428189f8fb Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.

PiperOrigin-RevId: 506994892
2023-02-03 14:28:45 -08:00
Yash Katariya
2f3d75aa03 Remove dependency of maps from pjit to avoid circular imports when importing pjit in api.py.
PiperOrigin-RevId: 497230514
2022-12-22 13:35:23 -08:00
Peter Hawkins
2c6c30d458 Bump the minimum jaxlib version to 0.4.1.
Jaxlib 0.4.1 has XLA client version 109 and MLIR API version 39.
2022-12-19 17:49:24 +00:00
lenamartens
e80c34d624 Don't donate arguments in jit/pmap/pjit when debug_nans=True. 2022-11-08 13:33:59 +00:00
Peter Hawkins
c657449528 Copybara import of the project:
--
d39bdefb33a19e407c352df27fb04127f4fe8a1d by Peter Hawkins <phawkins@google.com>:

Migrate more tests from jtu.cases_from_list to jtu.sample_product.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/12717 from hawkinsp:sampletest d39bdefb33a19e407c352df27fb04127f4fe8a1d
PiperOrigin-RevId: 480136538
2022-10-10 11:35:32 -07:00
Yash Katariya
fb8558cfdd Add jax_array coverage to debug_nans_test
PiperOrigin-RevId: 478079509
2022-09-30 14:21:32 -07:00
Yash Katariya
9ff570e6c3 Make debug_nans_test.py pass with jax_array=1. Both with enabled and disabled jax_array flag and --pdb_post_mortem, we fall to the same place.
PiperOrigin-RevId: 477850567
2022-09-29 16:29:58 -07:00
Yash Katariya
b4e1d0af8a Propagate name through ExecuteReplicated for dispatch.check_special
PiperOrigin-RevId: 477351323
2022-09-27 21:32:32 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Yash Katariya
7fbf8ec669 Fix Forward. The fix is on the user's end. Original PR: https://github.com/google/jax/pull/12217
Co-authored-by: Matthew Johnson <mattjj@google.com>
Co-authored-by: Yash Katariya <yashkatariya@google.com>
PiperOrigin-RevId: 472999907
2022-09-08 08:49:40 -07:00
jax authors
14f1a345a1 roll back breakage
PiperOrigin-RevId: 472949225
2022-09-08 03:59:54 -07:00
Yash Katariya
b7e4e44cbf DCE jaxpr and trivial_jaxpr support for lower_sharding_computation
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 471274989
2022-09-06 14:09:10 -07:00
Jeppe Klitgaard
838a05329d feat: validate jit args 2022-05-18 21:54:47 +01:00
Peter Hawkins
634f58c7d5 Enable a number of tests on GPU.
In particular, pjit/xmap work on CPU these days.

PiperOrigin-RevId: 446085110
2022-05-02 18:57:27 -07:00
Matthew Johnson
8bc8e40e72 debug_nans: don't return results of successfully running de-optimized function 2022-04-12 14:40:19 -07:00
Yash Katariya
687a7630ee Deprecate maps.mesh and replace it with maps.Mesh.
PiperOrigin-RevId: 430489855
2022-02-23 10:47:06 -08:00
Peter Hawkins
3fd3c46f20 Increase minimum jaxlib version to 0.1.74. 2021-11-18 15:06:58 -05:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Peter Hawkins
2c2f4033cc Move contents of jax.lib to jax._src.lib.
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.

PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
Jean-Baptiste Lespiau
f6f1debf70 Add post_hook support for pmap, to support debug_nans and debug_infs.
It's the exact same code as for JIT. We just modify the Python function to accept ShardedDeviceArray in addition to DeviceArray objects. The test is updated accordingly.

PiperOrigin-RevId: 391272270
2021-08-17 06:11:47 -07:00
Peter Hawkins
1aec989aa3 Fix "Store empty" error due to debug_nans corrupting cache entries.
Rather than mutating the existing WrappedFun, clone it with fresh stores. The stores aren't connected to anything, but that's fine: we can treat the deoptimized computation as a throwaway computation; the "real" computation is the jit-compiled version and we are ultimately going to use its stores if we don't throw an exception.
2021-08-11 10:59:42 -04:00
Jake VanderPlas
80d8f2d56c jnp.sinc: fix NaNs at x=0 2021-06-10 09:14:07 -07:00
Peter Hawkins
26e9ebcdae Move jax.api to jax._src.api.
PiperOrigin-RevId: 368233837
2021-04-13 09:43:24 -07:00
Matthew Johnson
2b79264354 remove disable_omnistaging mechanism 2021-03-29 15:26:57 -07:00
Peter Hawkins
cac1b891ce [JAX] Refactor NaN/Inf checking in jitted functions.
Avoid performing NaN/Inf checking in the common path for calling a jit-ted function. Instead, add a global/thread-local `posthook` function that, if, set, the C++ jit code calls with the inputs (function, args, kwargs, outputs). Use the posthook feature to implement NaN checking.

Add a `_cache_miss` attribute to the C++ JIT function objects to allow the NaN checking code to extract and call the cache miss function.

PiperOrigin-RevId: 365108787
2021-03-25 13:13:02 -07:00
Matthew Johnson
fd7b286ec9 unify configuration state handling 2021-03-23 18:56:01 -07:00
Skye Wanderman-Milne
c56649aaac Make jax_debug_nans and jax_debug_infs work with pmap, xmap, and pjit.
Note that unlike in the jit case, this doesn't rerun the function in
op-by-op mode when it finds a nan, since we don't have op-by-op
parallel execution yet :)

This change doesn't appear to regress performance:

```
---------Benchmark summary for pmap_shard_outputs---------
  nouts    nshards       mean      %std    relative    mean/baseline
-------  ---------  ---------  --------  ----------  ---------------
     10          8   0.105598  5.06671      1               1.00693
    100          8   0.287756  0.870751     2.72502         0.973204
    500          8   1.20119   0.823624    11.3752          0.955185
   1000          8   2.56071   0           24.2497          0.983063
   5000          8  12.909     0          122.247           0.965925
    100          2   0.173727  5.15115      1.64518         0.98918
    100          4   0.207774  3.71411      1.9676          0.955849
    100          8   0.286103  1.60243      2.70937         0.971869
    100        100   2.34168   0           22.1755          0.904475
    100        500  15.9558    0          151.1             1.00483
```

Fixes #6044
2021-03-12 16:22:55 -08:00
Jake VanderPlas
5e7be4a61f Cleanup: remove obsolete jaxlib version checks 2021-02-04 15:13:39 -08:00
George Necula
0e932aeb72
Update debug_nans_test.py
Fix typo
2021-01-15 14:40:37 +02:00
Mike Innes
0e73bb9281 inf checker tests 2021-01-06 14:43:05 +00:00
Jean-Baptiste Lespiau
5a097f5ca9 Gate some jax_jit test with a version check. 2020-10-12 20:04:19 +02:00
Jean-Baptiste Lespiau
c1e25953a3 Add support for jax_debug_nans and fix the last few glitches with the C++ jax.jit.
- Sorting the keyword arguments must be done on the string, because we go through the Python path which uses flatten() which sort them by string.
- Some error with obj == obj which is the same as obj.is(obj) and not obj.equal(obj).
- Moves all the Python tests to the C++ tests (which also run on the _python_jit).

PiperOrigin-RevId: 336671123
2020-10-12 08:50:13 -07:00
Matthew Johnson
24de811a39 move a debug_nans test into debug_nans test file 2020-10-08 13:34:56 -07:00
Matthew Johnson
09f2be15d2 wait for result in debug_nans_test 2020-10-08 13:00:32 -07:00