Jake VanderPlas
fbe4f10403
Change to simpler import for jax.config
2023-04-21 11:51:22 -07:00
Yash Katariya
738dd719bd
Remove experimental_cpp_pmap flag since it is always on
...
PiperOrigin-RevId: 522631405
2023-04-07 10:42:11 -07:00
Peter Hawkins
dea7450e4e
Remove references to jax.config.jax_array, which is always True at head.
...
PiperOrigin-RevId: 516970232
2023-03-15 17:09:11 -07:00
Yash Katariya
52a7701dda
Replace usage of {in|out}_axis_resources with {in|out}_shardings
...
PiperOrigin-RevId: 513040164
2023-02-28 14:29:09 -08:00
Yash Katariya
418c2f9d2a
Rename in_axis_resources
and out_axis_resources
with in_shardings
and out_shardings
. This is just a simple name replacement. It does not change any of the current pjit semantics and doesn't break any code.
...
This is a safe and trivial name replacement. It does not change any of the semantics. You can still pass in PatitionSpecs to in_shardings and out_shardings.
PiperOrigin-RevId: 510671300
2023-02-18 10:00:36 -08:00
Yash Katariya
1c651f2ea4
Catch the NaN's and raise a better error message when jax_debug_nans flag is True.
...
PiperOrigin-RevId: 509552717
2023-02-14 09:27:36 -08:00
Peter Hawkins
428189f8fb
Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
...
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.
PiperOrigin-RevId: 506994892
2023-02-03 14:28:45 -08:00
Yash Katariya
2f3d75aa03
Remove dependency of maps from pjit to avoid circular imports when importing pjit in api.py.
...
PiperOrigin-RevId: 497230514
2022-12-22 13:35:23 -08:00
Peter Hawkins
2c6c30d458
Bump the minimum jaxlib version to 0.4.1.
...
Jaxlib 0.4.1 has XLA client version 109 and MLIR API version 39.
2022-12-19 17:49:24 +00:00
lenamartens
e80c34d624
Don't donate arguments in jit/pmap/pjit when debug_nans=True.
2022-11-08 13:33:59 +00:00
Peter Hawkins
c657449528
Copybara import of the project:
...
--
d39bdefb33a19e407c352df27fb04127f4fe8a1d by Peter Hawkins <phawkins@google.com>:
Migrate more tests from jtu.cases_from_list to jtu.sample_product.
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/12717 from hawkinsp:sampletest d39bdefb33a19e407c352df27fb04127f4fe8a1d
PiperOrigin-RevId: 480136538
2022-10-10 11:35:32 -07:00
Yash Katariya
fb8558cfdd
Add jax_array coverage to debug_nans_test
...
PiperOrigin-RevId: 478079509
2022-09-30 14:21:32 -07:00
Yash Katariya
9ff570e6c3
Make debug_nans_test.py pass with jax_array=1. Both with enabled and disabled jax_array flag and --pdb_post_mortem, we fall to the same place.
...
PiperOrigin-RevId: 477850567
2022-09-29 16:29:58 -07:00
Yash Katariya
b4e1d0af8a
Propagate name
through ExecuteReplicated for dispatch.check_special
...
PiperOrigin-RevId: 477351323
2022-09-27 21:32:32 -07:00
Peter Hawkins
ba557d5e1b
Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
...
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.
PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Yash Katariya
7fbf8ec669
Fix Forward. The fix is on the user's end. Original PR: https://github.com/google/jax/pull/12217
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
Co-authored-by: Yash Katariya <yashkatariya@google.com>
PiperOrigin-RevId: 472999907
2022-09-08 08:49:40 -07:00
jax authors
14f1a345a1
roll back breakage
...
PiperOrigin-RevId: 472949225
2022-09-08 03:59:54 -07:00
Yash Katariya
b7e4e44cbf
DCE jaxpr and trivial_jaxpr support for lower_sharding_computation
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 471274989
2022-09-06 14:09:10 -07:00
Jeppe Klitgaard
838a05329d
feat: validate jit args
2022-05-18 21:54:47 +01:00
Peter Hawkins
634f58c7d5
Enable a number of tests on GPU.
...
In particular, pjit/xmap work on CPU these days.
PiperOrigin-RevId: 446085110
2022-05-02 18:57:27 -07:00
Matthew Johnson
8bc8e40e72
debug_nans: don't return results of successfully running de-optimized function
2022-04-12 14:40:19 -07:00
Yash Katariya
687a7630ee
Deprecate maps.mesh
and replace it with maps.Mesh
.
...
PiperOrigin-RevId: 430489855
2022-02-23 10:47:06 -08:00
Peter Hawkins
3fd3c46f20
Increase minimum jaxlib version to 0.1.74.
2021-11-18 15:06:58 -05:00
Peter Hawkins
db2e91eba2
Move jax.test_util to jax._src.test_util.
...
Add forwarding shims for names used by external clients of JAX in practice.
PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Peter Hawkins
2c2f4033cc
Move contents of jax.lib to jax._src.lib.
...
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.
PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
Jean-Baptiste Lespiau
f6f1debf70
Add post_hook support for pmap, to support debug_nans and debug_infs.
...
It's the exact same code as for JIT. We just modify the Python function to accept ShardedDeviceArray in addition to DeviceArray objects. The test is updated accordingly.
PiperOrigin-RevId: 391272270
2021-08-17 06:11:47 -07:00
Peter Hawkins
1aec989aa3
Fix "Store empty" error due to debug_nans corrupting cache entries.
...
Rather than mutating the existing WrappedFun, clone it with fresh stores. The stores aren't connected to anything, but that's fine: we can treat the deoptimized computation as a throwaway computation; the "real" computation is the jit-compiled version and we are ultimately going to use its stores if we don't throw an exception.
2021-08-11 10:59:42 -04:00
Jake VanderPlas
80d8f2d56c
jnp.sinc: fix NaNs at x=0
2021-06-10 09:14:07 -07:00
Peter Hawkins
26e9ebcdae
Move jax.api to jax._src.api.
...
PiperOrigin-RevId: 368233837
2021-04-13 09:43:24 -07:00
Matthew Johnson
2b79264354
remove disable_omnistaging mechanism
2021-03-29 15:26:57 -07:00
Peter Hawkins
cac1b891ce
[JAX] Refactor NaN/Inf checking in jitted functions.
...
Avoid performing NaN/Inf checking in the common path for calling a jit-ted function. Instead, add a global/thread-local `posthook` function that, if, set, the C++ jit code calls with the inputs (function, args, kwargs, outputs). Use the posthook feature to implement NaN checking.
Add a `_cache_miss` attribute to the C++ JIT function objects to allow the NaN checking code to extract and call the cache miss function.
PiperOrigin-RevId: 365108787
2021-03-25 13:13:02 -07:00
Matthew Johnson
fd7b286ec9
unify configuration state handling
2021-03-23 18:56:01 -07:00
Skye Wanderman-Milne
c56649aaac
Make jax_debug_nans and jax_debug_infs work with pmap, xmap, and pjit.
...
Note that unlike in the jit case, this doesn't rerun the function in
op-by-op mode when it finds a nan, since we don't have op-by-op
parallel execution yet :)
This change doesn't appear to regress performance:
```
---------Benchmark summary for pmap_shard_outputs---------
nouts nshards mean %std relative mean/baseline
------- --------- --------- -------- ---------- ---------------
10 8 0.105598 5.06671 1 1.00693
100 8 0.287756 0.870751 2.72502 0.973204
500 8 1.20119 0.823624 11.3752 0.955185
1000 8 2.56071 0 24.2497 0.983063
5000 8 12.909 0 122.247 0.965925
100 2 0.173727 5.15115 1.64518 0.98918
100 4 0.207774 3.71411 1.9676 0.955849
100 8 0.286103 1.60243 2.70937 0.971869
100 100 2.34168 0 22.1755 0.904475
100 500 15.9558 0 151.1 1.00483
```
Fixes #6044
2021-03-12 16:22:55 -08:00
Jake VanderPlas
5e7be4a61f
Cleanup: remove obsolete jaxlib version checks
2021-02-04 15:13:39 -08:00
George Necula
0e932aeb72
Update debug_nans_test.py
...
Fix typo
2021-01-15 14:40:37 +02:00
Mike Innes
0e73bb9281
inf checker tests
2021-01-06 14:43:05 +00:00
Jean-Baptiste Lespiau
5a097f5ca9
Gate some jax_jit test with a version check.
2020-10-12 20:04:19 +02:00
Jean-Baptiste Lespiau
c1e25953a3
Add support for jax_debug_nans and fix the last few glitches with the C++ jax.jit.
...
- Sorting the keyword arguments must be done on the string, because we go through the Python path which uses flatten() which sort them by string.
- Some error with obj == obj which is the same as obj.is(obj) and not obj.equal(obj).
- Moves all the Python tests to the C++ tests (which also run on the _python_jit).
PiperOrigin-RevId: 336671123
2020-10-12 08:50:13 -07:00
Matthew Johnson
24de811a39
move a debug_nans test into debug_nans test file
2020-10-08 13:34:56 -07:00
Matthew Johnson
09f2be15d2
wait for result in debug_nans_test
2020-10-08 13:00:32 -07:00
Jake VanderPlas
afce718eb1
Add ability to specify individual test targets
2020-06-29 11:08:57 -07:00
Peter Hawkins
66cea0277c
Fix test failures on GPU. ( #3572 )
2020-06-26 12:50:22 -04:00
Jake Vanderplas
9ee4ef1107
Cleanup: de-lint tests directory & add flake8 to travis ( #3304 )
...
* Cleanup: fix lint errors in tests/*.py
* Add flake8 step to travis
* add setup.cfg
2020-06-02 19:25:47 -07:00
Jake Vanderplas
bc30597780
Cleanup: remove unused imports in tests ( #3276 )
2020-06-01 11:49:35 -07:00
Peter Hawkins
b1bc841ae5
Replace np -> jnp, onp -> np in more places. ( #2973 )
...
* Replace np -> jnp, onp -> np in more places.
Context: #2370
* Fix typo in random_test.py
2020-05-05 16:40:41 -04:00
George Necula
fcdbe63f37
Trigger a Travis build ( #2477 )
...
* Remove more unused imports
* Fix warnings in travis.yml
2020-03-21 17:38:46 +01:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. ( #2117 )
...
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
James Bradbury
dd1e132eaa
test that debug_nans actually finds NaNs
2019-10-02 16:10:27 -07:00
James Bradbury
efc3a2c31a
address comments
2019-10-02 15:55:09 -07:00
James Bradbury
be4e156aee
use setUp and tearDown
2019-10-02 15:51:15 -07:00