82 Commits

Author SHA1 Message Date
Peter Hawkins
08789fd967 Exclude "util.py" and "config.py" from the main JAX bazel target.
This completes the process of splitting these targets out of :jax.

PiperOrigin-RevId: 515340312
2023-03-09 08:17:03 -08:00
Peter Hawkins
0e05a7987f Split some submodules out of //jax under Bazel.
Add separate BUILD targets
* :version - for version.py
* _src/lib - wrapping the jaxlib shims.
* :util - for util.py
* :config - for config.py

PiperOrigin-RevId: 515307923
2023-03-09 05:27:34 -08:00
pizzud
22cbf95e07 lax_vmap_test: Extend timeout so that the TPU variant can run in ASAN.
Unfortunately we can't conditionally change the timeout, as size and timeout
are both non-configurable even if jax_test supported setting the size.

PiperOrigin-RevId: 514745247
2023-03-07 08:49:42 -08:00
Peter Hawkins
0bb75afaa6 Remove global_device_array from shared jax bazel library.
Require Bazel users to depend explicitly on :global_device_array. Change in preparation for removing global device arrays.

PiperOrigin-RevId: 511273814
2023-02-21 12:27:44 -08:00
Peter Hawkins
f7734fd6a4 Limit visibility of Bazel target jax:global_device_array.
PiperOrigin-RevId: 510521459
2023-02-17 14:30:05 -08:00
pizzud
631e4ed7e0 lax_test: Create a separate module for lax-specific test utils in a new package.
These utils are currently shared with lax_vmap_test by importing lax_test as a
library, which is an odd thing to do.

The new package and the module within it are not built into the wheel, as these
are internal utilities for JAX's tests, not utilities for JAX users writing
their own tests.

Followup changes will add additional existing internal test utilities to this
package. This will allow removing sys.path manipulation from
deprecation_module_test and hopefully lazy_loader_test, as well as removing
the non-public test_util.py from _src to make it clearer that it should not be
used from outside JAX.

PiperOrigin-RevId: 510260230
2023-02-16 15:29:41 -08:00
Peter Hawkins
c368562529 Add keep_dep tag to :global_device_array build target to hint that it should be kept.
PiperOrigin-RevId: 510241400
2023-02-16 14:15:21 -08:00
Peter Hawkins
43b615c0a0 Move global_device_array into its own BUILD target.
PiperOrigin-RevId: 510229248
2023-02-16 13:30:40 -08:00
jax authors
b8d6efe22f Merge pull request #14273 from mattjj:shard-map
PiperOrigin-RevId: 506820113
2023-02-02 23:25:39 -08:00
Matthew Johnson
ff1e9b3973 shard_map (shmap) prototype and JEP
Co-authored-by: Sharad Vikram <sharadmv@google.com>
Co-authored-by: Sholto Douglas <sholto@google.com>
2023-02-02 23:01:30 -08:00
Peter Hawkins
c90a85403b Merge pull request #14248 from jakevdp:dead-code
PiperOrigin-RevId: 506405131
2023-02-01 21:25:46 +00:00
Sharad Vikram
c9a57e1b44 Delete jax.experimental.callback
PiperOrigin-RevId: 501760507
2023-01-12 22:58:31 -08:00
Qiao Zhang
4d1c4bc761 Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 491445515
2022-11-28 14:31:48 -08:00
jax authors
d1fbdbc1cf Rollback of "Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module."
PiperOrigin-RevId: 490499003
2022-11-23 07:48:05 -08:00
Qiao Zhang
78963b6020 Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 490387796
2022-11-22 18:53:29 -08:00
Jake VanderPlas
66262901f0 [sparse] improve testing framework 2022-11-16 09:58:06 -08:00
Yash Katariya
9e4114f0f1 Move array.py and sharding.py from experimental/ to _src/.
PiperOrigin-RevId: 477201711
2022-09-27 10:06:52 -07:00
Jake VanderPlas
265b39d23f Add pytype_srcs to main jax BUILD rule
PiperOrigin-RevId: 476989241
2022-09-26 14:18:13 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
79406757d0 Remove deprecated jax.experimental.optimizers
The new location is jax.example_libraries.optimizers
2022-08-09 08:50:59 -07:00
Peter Hawkins
b865111996 Refactor BUILD files to avoid individually naming Python dependencies.
Add a parametric py_deps() macro for adding Python package dependencies for Bazel rules.

Fix build failure with dangling matplotlib reference.

PiperOrigin-RevId: 465562141
2022-08-05 07:49:20 -07:00
Jake VanderPlas
91dbcbf525 Remove deprecated jax.experimental.stax
The new location is jax.example_libraries.stax
2022-08-02 16:50:06 -07:00
Yash Katariya
9a5af235da Delete sharded_jit
PiperOrigin-RevId: 464081692
2022-07-29 08:19:52 -07:00
George Necula
afa8f5acb4 Remove jax.experimental.loops. See CHANGELOG
PiperOrigin-RevId: 463297399
2022-07-26 03:39:47 -07:00
jax authors
b5e6145a42 Merge pull request #11359 from hawkinsp:bazel
PiperOrigin-RevId: 459234031
2022-07-06 06:13:20 -07:00
Peter Hawkins
1c75eee1ff Document how to run tests using Bazel.
* Add a new --configure_only option to build.py to allow build.py to generate a .bazelrc without necessarily building jaxlib.
* Add a bazel flag that make the dependency of //jax on //jaxlib optional. If //jaxlib isn't built by bazel, then tests will implicitly use a preinstalled jaxlib.
2022-07-06 08:30:35 -04:00
Peter Hawkins
1fc9afd03a Add support for running JAX tests under Bazel.
This is an alternative method for running the tests that some users may prefer: pytest is and will remain fully supported.

To use this, one creates a .bazelrc by running the existing `build.py` script, and then one can run the tests by running:
```
bazel test -c opt //tests/...
```

Issue #7323

PiperOrigin-RevId: 458551208
2022-07-01 15:07:22 -07:00
Peter Hawkins
b03466a390 Remove jax/BUILD file.
This Bazel build file is unused; we only use Bazel to build jaxlib, which does not include any files from jax/.

PiperOrigin-RevId: 374709682
2021-05-19 12:55:34 -07:00
Peter Hawkins
d481013f47 Add a CPU feature guard module to JAX.
To make sure that the CPU feature guard happens first, before any other code that may use instructions that do not exist, use a separate C extension module.

Fixes https://github.com/google/jax/issues/6671

PiperOrigin-RevId: 374683190
2021-05-19 10:58:35 -07:00
George Necula
342d62436f [host_callback] Add support for pjit of host_callback.
Currently, all XLA side-effect ops inside a sharded computation must have
explicit sharding. This includes the outfeed and infeed used by host_callback.

The implementation here uses AssignDevice sharding for both the outfeed and the
infeed. This means that before outfeed, the devices will do an all_gather and
the first device will make the outfeed. The host callback will receive a single
outfeed with the entire array, and is supposed to return the entire array. This
gets sent to the same device that issued to outfeed, which is responsible to
send the respective slices to the other participating devices.

PiperOrigin-RevId: 370711606
2021-04-27 10:30:13 -07:00
Jake VanderPlas
90d606fe25 Remove jax.experimental.doubledouble
PiperOrigin-RevId: 369740697
2021-04-21 14:52:14 -07:00
Adam Paszke
f285f0cc47 Fix pjit resource checks to verify the shapes against the local mesh.
PiperOrigin-RevId: 369263691
2021-04-19 11:28:47 -07:00
Jake VanderPlas
0d4bcde7ca Add experimental/sparse_ops & cusparse wrappers in jaxlib
PiperOrigin-RevId: 368663407
2021-04-15 10:11:10 -07:00
jax authors
babf249705 Merge pull request #5717 from google:dynamic-shapes2
PiperOrigin-RevId: 357851603
2021-02-16 18:45:13 -08:00
jax authors
14cc3f89a5 Merge pull request #5672 from skye:pjit
PiperOrigin-RevId: 356774892
2021-02-10 10:57:06 -08:00
jax authors
543dcb37e3 Merge pull request #5527 from google:remove-soft-pmap
PiperOrigin-RevId: 354328027
2021-01-28 09:29:57 -08:00
jax authors
f4b5ff9d46 Merge pull request #5488 from jakevdp:x64-contextmanager
PiperOrigin-RevId: 353722592
2021-01-25 13:51:12 -08:00
Peter Hawkins
9bdc2ecc66 Consolidate build macros into a single jax.bzl file.
PiperOrigin-RevId: 352871429
2021-01-20 14:06:22 -08:00
Peter Hawkins
929a684a39 Small cleanups to dependency structure.
PiperOrigin-RevId: 352853244
2021-01-20 12:43:28 -08:00
jax authors
2a699d0b04 Merge pull request #5038 from skye:sharded_jit_namespace
PiperOrigin-RevId: 346121390
2020-12-07 10:18:56 -08:00
Adam Paszke
d3136b44c8 Remove fake xmap resources, remove gmap
xmap can now handle real devices, so there's no point in maintaining the
simulated codepaths. Also, remove single-dimensional gmap as it will
have to be superseeded by a more xmap-friendly alternative.
2020-12-02 14:35:59 +00:00
Peter Hawkins
195e13c14b Remove jax.experimental.optix.
optix has become its own Python package (optax). You should use optax instead.

PiperOrigin-RevId: 343291598
2020-11-19 08:03:15 -08:00
Peter Hawkins
2b8d840cc3 [JAX] Remove uses of the deprecated jax.experimental.vectorize.
jax.numpy.vectorize should be used instead.

PiperOrigin-RevId: 341836454
2020-11-11 08:34:48 -08:00
Peter Hawkins
f58f1ee456 [JAX] Use PocketFFT for FFTs on CPU instead of Eigen.
PocketFFT is the same FFT library used by NumPy (although we are using the C++ variant rather than the C variant.)

For the benchmark in #2952 on my workstation:

Before:
```
907.3490574884647
max:     4.362646594533903e-08
mean:    6.237288307614869e-09
min:     0.0
numpy fft execution time [ms]:   37.088446617126465
jax fft execution time [ms]:     74.93342399597168
```

After:
```
907.3490574884647
max:     1.9057386696477137e-12
mean:    3.9326737908882566e-13
min:     0.0
numpy fft execution time [ms]:   37.756404876708984
jax fft execution time [ms]:     28.128278255462646
```

Fixes https://github.com/google/jax/issues/2952

PiperOrigin-RevId: 338743753
2020-10-23 14:20:32 -07:00
Peter Hawkins
6acb46516e Move most of the implementation of jax.scipy into jax._src.scipy. 2020-10-16 17:04:25 -04:00
Jake VanderPlas
484ec3e2d6 Internal change
PiperOrigin-RevId: 336957861
2020-10-13 14:36:45 -07:00
Peter Hawkins
417de0d351
Add jit to jax.image.resize (#3714)
* Add image/ directory to Bazel build.

* Use a jit on jax.image.resize to reduce compilation time.

Relax bfloat16 test tolerance.
2020-07-10 10:32:13 -04:00
George Necula
2a9c2d22cf
Cleanup last use of msgpack. (#3668)
This is not needed for the new host_calback runtime.
2020-07-06 11:20:22 +03:00
George Necula
c375adf52a
Implementation of id_tap/id_print using outfeed. (#3006)
This was already merged as #2791 but reverted due to XLA crashes.

This reverts commit 769d703b7ac1011babef6289382f1a14d7aafc42.
2020-05-08 17:18:11 +03:00
George Necula
769d703b7a Undo the id_print/id_tap feature (PR #2791)
Crashes on Travis with the latest 0.1.46. Need to figure out what is going on
2020-05-07 20:48:33 +03:00