51 Commits

Author SHA1 Message Date
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Yash Katariya
738dd719bd Remove experimental_cpp_pmap flag since it is always on
PiperOrigin-RevId: 522631405
2023-04-07 10:42:11 -07:00
Peter Hawkins
dea7450e4e Remove references to jax.config.jax_array, which is always True at head.
PiperOrigin-RevId: 516970232
2023-03-15 17:09:11 -07:00
Yash Katariya
52a7701dda Replace usage of {in|out}_axis_resources with {in|out}_shardings
PiperOrigin-RevId: 513040164
2023-02-28 14:29:09 -08:00
Yash Katariya
418c2f9d2a Rename in_axis_resources and out_axis_resources with in_shardings and out_shardings. This is just a simple name replacement. It does not change any of the current pjit semantics and doesn't break any code.
This is a safe and trivial name replacement. It does not change any of the semantics. You can still pass in PatitionSpecs to in_shardings and out_shardings.

PiperOrigin-RevId: 510671300
2023-02-18 10:00:36 -08:00
Yash Katariya
1c651f2ea4 Catch the NaN's and raise a better error message when jax_debug_nans flag is True.
PiperOrigin-RevId: 509552717
2023-02-14 09:27:36 -08:00
Peter Hawkins
428189f8fb Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.

PiperOrigin-RevId: 506994892
2023-02-03 14:28:45 -08:00
Yash Katariya
2f3d75aa03 Remove dependency of maps from pjit to avoid circular imports when importing pjit in api.py.
PiperOrigin-RevId: 497230514
2022-12-22 13:35:23 -08:00
Peter Hawkins
2c6c30d458 Bump the minimum jaxlib version to 0.4.1.
Jaxlib 0.4.1 has XLA client version 109 and MLIR API version 39.
2022-12-19 17:49:24 +00:00
lenamartens
e80c34d624 Don't donate arguments in jit/pmap/pjit when debug_nans=True. 2022-11-08 13:33:59 +00:00
Peter Hawkins
c657449528 Copybara import of the project:
--
d39bdefb33a19e407c352df27fb04127f4fe8a1d by Peter Hawkins <phawkins@google.com>:

Migrate more tests from jtu.cases_from_list to jtu.sample_product.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/12717 from hawkinsp:sampletest d39bdefb33a19e407c352df27fb04127f4fe8a1d
PiperOrigin-RevId: 480136538
2022-10-10 11:35:32 -07:00
Yash Katariya
fb8558cfdd Add jax_array coverage to debug_nans_test
PiperOrigin-RevId: 478079509
2022-09-30 14:21:32 -07:00
Yash Katariya
9ff570e6c3 Make debug_nans_test.py pass with jax_array=1. Both with enabled and disabled jax_array flag and --pdb_post_mortem, we fall to the same place.
PiperOrigin-RevId: 477850567
2022-09-29 16:29:58 -07:00
Yash Katariya
b4e1d0af8a Propagate name through ExecuteReplicated for dispatch.check_special
PiperOrigin-RevId: 477351323
2022-09-27 21:32:32 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Yash Katariya
7fbf8ec669 Fix Forward. The fix is on the user's end. Original PR: https://github.com/google/jax/pull/12217
Co-authored-by: Matthew Johnson <mattjj@google.com>
Co-authored-by: Yash Katariya <yashkatariya@google.com>
PiperOrigin-RevId: 472999907
2022-09-08 08:49:40 -07:00
jax authors
14f1a345a1 roll back breakage
PiperOrigin-RevId: 472949225
2022-09-08 03:59:54 -07:00
Yash Katariya
b7e4e44cbf DCE jaxpr and trivial_jaxpr support for lower_sharding_computation
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 471274989
2022-09-06 14:09:10 -07:00
Jeppe Klitgaard
838a05329d feat: validate jit args 2022-05-18 21:54:47 +01:00
Peter Hawkins
634f58c7d5 Enable a number of tests on GPU.
In particular, pjit/xmap work on CPU these days.

PiperOrigin-RevId: 446085110
2022-05-02 18:57:27 -07:00
Matthew Johnson
8bc8e40e72 debug_nans: don't return results of successfully running de-optimized function 2022-04-12 14:40:19 -07:00
Yash Katariya
687a7630ee Deprecate maps.mesh and replace it with maps.Mesh.
PiperOrigin-RevId: 430489855
2022-02-23 10:47:06 -08:00
Peter Hawkins
3fd3c46f20 Increase minimum jaxlib version to 0.1.74. 2021-11-18 15:06:58 -05:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Peter Hawkins
2c2f4033cc Move contents of jax.lib to jax._src.lib.
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.

PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
Jean-Baptiste Lespiau
f6f1debf70 Add post_hook support for pmap, to support debug_nans and debug_infs.
It's the exact same code as for JIT. We just modify the Python function to accept ShardedDeviceArray in addition to DeviceArray objects. The test is updated accordingly.

PiperOrigin-RevId: 391272270
2021-08-17 06:11:47 -07:00
Peter Hawkins
1aec989aa3 Fix "Store empty" error due to debug_nans corrupting cache entries.
Rather than mutating the existing WrappedFun, clone it with fresh stores. The stores aren't connected to anything, but that's fine: we can treat the deoptimized computation as a throwaway computation; the "real" computation is the jit-compiled version and we are ultimately going to use its stores if we don't throw an exception.
2021-08-11 10:59:42 -04:00
Jake VanderPlas
80d8f2d56c jnp.sinc: fix NaNs at x=0 2021-06-10 09:14:07 -07:00
Peter Hawkins
26e9ebcdae Move jax.api to jax._src.api.
PiperOrigin-RevId: 368233837
2021-04-13 09:43:24 -07:00
Matthew Johnson
2b79264354 remove disable_omnistaging mechanism 2021-03-29 15:26:57 -07:00
Peter Hawkins
cac1b891ce [JAX] Refactor NaN/Inf checking in jitted functions.
Avoid performing NaN/Inf checking in the common path for calling a jit-ted function. Instead, add a global/thread-local `posthook` function that, if, set, the C++ jit code calls with the inputs (function, args, kwargs, outputs). Use the posthook feature to implement NaN checking.

Add a `_cache_miss` attribute to the C++ JIT function objects to allow the NaN checking code to extract and call the cache miss function.

PiperOrigin-RevId: 365108787
2021-03-25 13:13:02 -07:00
Matthew Johnson
fd7b286ec9 unify configuration state handling 2021-03-23 18:56:01 -07:00
Skye Wanderman-Milne
c56649aaac Make jax_debug_nans and jax_debug_infs work with pmap, xmap, and pjit.
Note that unlike in the jit case, this doesn't rerun the function in
op-by-op mode when it finds a nan, since we don't have op-by-op
parallel execution yet :)

This change doesn't appear to regress performance:

```
---------Benchmark summary for pmap_shard_outputs---------
  nouts    nshards       mean      %std    relative    mean/baseline
-------  ---------  ---------  --------  ----------  ---------------
     10          8   0.105598  5.06671      1               1.00693
    100          8   0.287756  0.870751     2.72502         0.973204
    500          8   1.20119   0.823624    11.3752          0.955185
   1000          8   2.56071   0           24.2497          0.983063
   5000          8  12.909     0          122.247           0.965925
    100          2   0.173727  5.15115      1.64518         0.98918
    100          4   0.207774  3.71411      1.9676          0.955849
    100          8   0.286103  1.60243      2.70937         0.971869
    100        100   2.34168   0           22.1755          0.904475
    100        500  15.9558    0          151.1             1.00483
```

Fixes #6044
2021-03-12 16:22:55 -08:00
Jake VanderPlas
5e7be4a61f Cleanup: remove obsolete jaxlib version checks 2021-02-04 15:13:39 -08:00
George Necula
0e932aeb72
Update debug_nans_test.py
Fix typo
2021-01-15 14:40:37 +02:00
Mike Innes
0e73bb9281 inf checker tests 2021-01-06 14:43:05 +00:00
Jean-Baptiste Lespiau
5a097f5ca9 Gate some jax_jit test with a version check. 2020-10-12 20:04:19 +02:00
Jean-Baptiste Lespiau
c1e25953a3 Add support for jax_debug_nans and fix the last few glitches with the C++ jax.jit.
- Sorting the keyword arguments must be done on the string, because we go through the Python path which uses flatten() which sort them by string.
- Some error with obj == obj which is the same as obj.is(obj) and not obj.equal(obj).
- Moves all the Python tests to the C++ tests (which also run on the _python_jit).

PiperOrigin-RevId: 336671123
2020-10-12 08:50:13 -07:00
Matthew Johnson
24de811a39 move a debug_nans test into debug_nans test file 2020-10-08 13:34:56 -07:00
Matthew Johnson
09f2be15d2 wait for result in debug_nans_test 2020-10-08 13:00:32 -07:00
Jake VanderPlas
afce718eb1 Add ability to specify individual test targets 2020-06-29 11:08:57 -07:00
Peter Hawkins
66cea0277c
Fix test failures on GPU. (#3572) 2020-06-26 12:50:22 -04:00
Jake Vanderplas
9ee4ef1107
Cleanup: de-lint tests directory & add flake8 to travis (#3304)
* Cleanup: fix lint errors in tests/*.py

* Add flake8 step to travis

* add setup.cfg
2020-06-02 19:25:47 -07:00
Jake Vanderplas
bc30597780
Cleanup: remove unused imports in tests (#3276) 2020-06-01 11:49:35 -07:00
Peter Hawkins
b1bc841ae5
Replace np -> jnp, onp -> np in more places. (#2973)
* Replace np -> jnp, onp -> np in more places.

Context: #2370

* Fix typo in random_test.py
2020-05-05 16:40:41 -04:00
George Necula
fcdbe63f37
Trigger a Travis build (#2477)
* Remove more unused imports
* Fix warnings in travis.yml
2020-03-21 17:38:46 +01:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
James Bradbury
dd1e132eaa test that debug_nans actually finds NaNs 2019-10-02 16:10:27 -07:00
James Bradbury
efc3a2c31a address comments 2019-10-02 15:55:09 -07:00
James Bradbury
be4e156aee use setUp and tearDown 2019-10-02 15:51:15 -07:00