63 Commits

Author SHA1 Message Date
Jake VanderPlas
79406757d0 Remove deprecated jax.experimental.optimizers
The new location is jax.example_libraries.optimizers
2022-08-09 08:50:59 -07:00
Peter Hawkins
b865111996 Refactor BUILD files to avoid individually naming Python dependencies.
Add a parametric py_deps() macro for adding Python package dependencies for Bazel rules.

Fix build failure with dangling matplotlib reference.

PiperOrigin-RevId: 465562141
2022-08-05 07:49:20 -07:00
Jake VanderPlas
91dbcbf525 Remove deprecated jax.experimental.stax
The new location is jax.example_libraries.stax
2022-08-02 16:50:06 -07:00
Yash Katariya
9a5af235da Delete sharded_jit
PiperOrigin-RevId: 464081692
2022-07-29 08:19:52 -07:00
George Necula
afa8f5acb4 Remove jax.experimental.loops. See CHANGELOG
PiperOrigin-RevId: 463297399
2022-07-26 03:39:47 -07:00
jax authors
b5e6145a42 Merge pull request #11359 from hawkinsp:bazel
PiperOrigin-RevId: 459234031
2022-07-06 06:13:20 -07:00
Peter Hawkins
1c75eee1ff Document how to run tests using Bazel.
* Add a new --configure_only option to build.py to allow build.py to generate a .bazelrc without necessarily building jaxlib.
* Add a bazel flag that make the dependency of //jax on //jaxlib optional. If //jaxlib isn't built by bazel, then tests will implicitly use a preinstalled jaxlib.
2022-07-06 08:30:35 -04:00
Peter Hawkins
1fc9afd03a Add support for running JAX tests under Bazel.
This is an alternative method for running the tests that some users may prefer: pytest is and will remain fully supported.

To use this, one creates a .bazelrc by running the existing `build.py` script, and then one can run the tests by running:
```
bazel test -c opt //tests/...
```

Issue #7323

PiperOrigin-RevId: 458551208
2022-07-01 15:07:22 -07:00
Peter Hawkins
b03466a390 Remove jax/BUILD file.
This Bazel build file is unused; we only use Bazel to build jaxlib, which does not include any files from jax/.

PiperOrigin-RevId: 374709682
2021-05-19 12:55:34 -07:00
Peter Hawkins
d481013f47 Add a CPU feature guard module to JAX.
To make sure that the CPU feature guard happens first, before any other code that may use instructions that do not exist, use a separate C extension module.

Fixes https://github.com/google/jax/issues/6671

PiperOrigin-RevId: 374683190
2021-05-19 10:58:35 -07:00
George Necula
342d62436f [host_callback] Add support for pjit of host_callback.
Currently, all XLA side-effect ops inside a sharded computation must have
explicit sharding. This includes the outfeed and infeed used by host_callback.

The implementation here uses AssignDevice sharding for both the outfeed and the
infeed. This means that before outfeed, the devices will do an all_gather and
the first device will make the outfeed. The host callback will receive a single
outfeed with the entire array, and is supposed to return the entire array. This
gets sent to the same device that issued to outfeed, which is responsible to
send the respective slices to the other participating devices.

PiperOrigin-RevId: 370711606
2021-04-27 10:30:13 -07:00
Jake VanderPlas
90d606fe25 Remove jax.experimental.doubledouble
PiperOrigin-RevId: 369740697
2021-04-21 14:52:14 -07:00
Adam Paszke
f285f0cc47 Fix pjit resource checks to verify the shapes against the local mesh.
PiperOrigin-RevId: 369263691
2021-04-19 11:28:47 -07:00
Jake VanderPlas
0d4bcde7ca Add experimental/sparse_ops & cusparse wrappers in jaxlib
PiperOrigin-RevId: 368663407
2021-04-15 10:11:10 -07:00
jax authors
babf249705 Merge pull request #5717 from google:dynamic-shapes2
PiperOrigin-RevId: 357851603
2021-02-16 18:45:13 -08:00
jax authors
14cc3f89a5 Merge pull request #5672 from skye:pjit
PiperOrigin-RevId: 356774892
2021-02-10 10:57:06 -08:00
jax authors
543dcb37e3 Merge pull request #5527 from google:remove-soft-pmap
PiperOrigin-RevId: 354328027
2021-01-28 09:29:57 -08:00
jax authors
f4b5ff9d46 Merge pull request #5488 from jakevdp:x64-contextmanager
PiperOrigin-RevId: 353722592
2021-01-25 13:51:12 -08:00
Peter Hawkins
9bdc2ecc66 Consolidate build macros into a single jax.bzl file.
PiperOrigin-RevId: 352871429
2021-01-20 14:06:22 -08:00
Peter Hawkins
929a684a39 Small cleanups to dependency structure.
PiperOrigin-RevId: 352853244
2021-01-20 12:43:28 -08:00
jax authors
2a699d0b04 Merge pull request #5038 from skye:sharded_jit_namespace
PiperOrigin-RevId: 346121390
2020-12-07 10:18:56 -08:00
Adam Paszke
d3136b44c8 Remove fake xmap resources, remove gmap
xmap can now handle real devices, so there's no point in maintaining the
simulated codepaths. Also, remove single-dimensional gmap as it will
have to be superseeded by a more xmap-friendly alternative.
2020-12-02 14:35:59 +00:00
Peter Hawkins
195e13c14b Remove jax.experimental.optix.
optix has become its own Python package (optax). You should use optax instead.

PiperOrigin-RevId: 343291598
2020-11-19 08:03:15 -08:00
Peter Hawkins
2b8d840cc3 [JAX] Remove uses of the deprecated jax.experimental.vectorize.
jax.numpy.vectorize should be used instead.

PiperOrigin-RevId: 341836454
2020-11-11 08:34:48 -08:00
Peter Hawkins
f58f1ee456 [JAX] Use PocketFFT for FFTs on CPU instead of Eigen.
PocketFFT is the same FFT library used by NumPy (although we are using the C++ variant rather than the C variant.)

For the benchmark in #2952 on my workstation:

Before:
```
907.3490574884647
max:     4.362646594533903e-08
mean:    6.237288307614869e-09
min:     0.0
numpy fft execution time [ms]:   37.088446617126465
jax fft execution time [ms]:     74.93342399597168
```

After:
```
907.3490574884647
max:     1.9057386696477137e-12
mean:    3.9326737908882566e-13
min:     0.0
numpy fft execution time [ms]:   37.756404876708984
jax fft execution time [ms]:     28.128278255462646
```

Fixes https://github.com/google/jax/issues/2952

PiperOrigin-RevId: 338743753
2020-10-23 14:20:32 -07:00
Peter Hawkins
6acb46516e Move most of the implementation of jax.scipy into jax._src.scipy. 2020-10-16 17:04:25 -04:00
Jake VanderPlas
484ec3e2d6 Internal change
PiperOrigin-RevId: 336957861
2020-10-13 14:36:45 -07:00
Peter Hawkins
417de0d351
Add jit to jax.image.resize (#3714)
* Add image/ directory to Bazel build.

* Use a jit on jax.image.resize to reduce compilation time.

Relax bfloat16 test tolerance.
2020-07-10 10:32:13 -04:00
George Necula
2a9c2d22cf
Cleanup last use of msgpack. (#3668)
This is not needed for the new host_calback runtime.
2020-07-06 11:20:22 +03:00
George Necula
c375adf52a
Implementation of id_tap/id_print using outfeed. (#3006)
This was already merged as #2791 but reverted due to XLA crashes.

This reverts commit 769d703b7ac1011babef6289382f1a14d7aafc42.
2020-05-08 17:18:11 +03:00
George Necula
769d703b7a Undo the id_print/id_tap feature (PR #2791)
Crashes on Travis with the latest 0.1.46. Need to figure out what is going on
2020-05-07 20:48:33 +03:00
George Necula
de685c9d5a An experiment for id_print implemented with outfeed
* Added print descriptors, support multiple types
* Added a state-passing mechanism to XLA interpreter
2020-05-07 16:24:13 +03:00
Peter Hawkins
1298e9e8c4
Fix some test failures. (#2713) 2020-04-14 18:23:19 -04:00
Peter Hawkins
2512ec6ebe
Glob over subdirectories in top-level BUILD file. (#2636)
Makes BUILD file more robust to directory structure changes.
2020-04-07 13:25:50 -04:00
Tom Hennigan
4c682b46bb
Add missing sources to jax build. (#2208) 2020-02-10 11:40:36 -05:00
Peter Hawkins
7dbc8dc1bc
Minimal changes to make Jax pass a pytype check. (#2024) 2020-01-18 08:26:23 -05:00
Peter Hawkins
938a7f8012
Remove :libjax alias from BUILD file. (#1996) 2020-01-14 11:33:21 -05:00
Skye Wanderman-Milne
46014da21d Fix c45d9db ("Drop Python 2 support from jax BUILD rule. #1965") 2020-01-08 15:09:34 -08:00
Skye Wanderman-Milne
c45d9dbc20
Drop Python 2 support from jax BUILD rule. (#1965) 2020-01-08 15:03:47 -08:00
Skye Wanderman-Milne
3ae4a41320
Add "loops" BUILD target. (#1771) 2019-11-26 12:44:09 -08:00
Matteo Hessel
7644b31d98
Add build target for optix 2019-11-01 15:39:19 +00:00
Peter Hawkins
0dd720cd8a
Disable some tests that fail. (#1587)
Add a BUILD rule for experimental/vectorize.py.
2019-10-29 11:04:55 -04:00
Peter Hawkins
c485a3cc50
Remove stale reference to lapax.py. (#1546)
Add some missing documentation references.
2019-10-21 13:47:36 -04:00
Skye Wanderman-Milne
796d369efa Remove licenses() rule comment in BUILD files.
Internal tooling doesn't like it.
2019-09-26 14:54:07 -07:00
Matthew Johnson
c04f566407 remove extra newline to help copybara 2019-09-26 11:56:22 -07:00
Daniel Weaver
cd61fbfca3 Add ode library to BUILD file 2019-09-25 16:21:51 +00:00
Skye Wanderman-Milne
73d512bdd2 Add nn/*.py to jax/BUILD 2019-09-04 18:27:55 -07:00
Peter Hawkins
03aea7c167 Remove old dummy build_jax target. 2019-06-27 22:13:04 -04:00
Peter Hawkins
c44222276d Add lax/*.py to jax library dependencies.
Rename :libjax -> :jax. Leave :libjax as an alias for backward compatibility.
2019-04-16 09:29:13 -04:00
Peter Hawkins
854e3b1500 Add missing jax/ops/ to jax/BUILD.
Fix typo in ops/scatter.py
2019-03-06 09:07:15 -05:00