87 Commits

Author SHA1 Message Date
pizzud
04def0b6ab lazy_loader_module: Move to new internal_test_util directory.
Now we no longer need to mess with sys.path in lazy_loader_test.

PiperOrigin-RevId: 515674188
2023-03-10 10:29:33 -08:00
Peter Hawkins
d58be3d4df Split source_info_util into its own Bazel target.
PiperOrigin-RevId: 515646269
2023-03-10 08:41:06 -08:00
Peter Hawkins
7bfd89a89c Split _src modules cloud_tpu_init, lazy_loader, path, monitoring into their own pytype_library Bazel targets.
PiperOrigin-RevId: 515420193
2023-03-09 13:11:04 -08:00
Peter Hawkins
7fd1e2ff47 Split _src/traceback_util.py into its own Bazel target.
Improve its type annotations.

PiperOrigin-RevId: 515376365
2023-03-09 10:33:47 -08:00
Peter Hawkins
9912a8eb56 Split _src/pretty_printer.py into its own Bazel target.
PiperOrigin-RevId: 515348089
2023-03-09 08:51:30 -08:00
Peter Hawkins
08789fd967 Exclude "util.py" and "config.py" from the main JAX bazel target.
This completes the process of splitting these targets out of :jax.

PiperOrigin-RevId: 515340312
2023-03-09 08:17:03 -08:00
Peter Hawkins
0e05a7987f Split some submodules out of //jax under Bazel.
Add separate BUILD targets
* :version - for version.py
* _src/lib - wrapping the jaxlib shims.
* :util - for util.py
* :config - for config.py

PiperOrigin-RevId: 515307923
2023-03-09 05:27:34 -08:00
pizzud
22cbf95e07 lax_vmap_test: Extend timeout so that the TPU variant can run in ASAN.
Unfortunately we can't conditionally change the timeout, as size and timeout
are both non-configurable even if jax_test supported setting the size.

PiperOrigin-RevId: 514745247
2023-03-07 08:49:42 -08:00
Peter Hawkins
0bb75afaa6 Remove global_device_array from shared jax bazel library.
Require Bazel users to depend explicitly on :global_device_array. Change in preparation for removing global device arrays.

PiperOrigin-RevId: 511273814
2023-02-21 12:27:44 -08:00
Peter Hawkins
f7734fd6a4 Limit visibility of Bazel target jax:global_device_array.
PiperOrigin-RevId: 510521459
2023-02-17 14:30:05 -08:00
pizzud
631e4ed7e0 lax_test: Create a separate module for lax-specific test utils in a new package.
These utils are currently shared with lax_vmap_test by importing lax_test as a
library, which is an odd thing to do.

The new package and the module within it are not built into the wheel, as these
are internal utilities for JAX's tests, not utilities for JAX users writing
their own tests.

Followup changes will add additional existing internal test utilities to this
package. This will allow removing sys.path manipulation from
deprecation_module_test and hopefully lazy_loader_test, as well as removing
the non-public test_util.py from _src to make it clearer that it should not be
used from outside JAX.

PiperOrigin-RevId: 510260230
2023-02-16 15:29:41 -08:00
Peter Hawkins
c368562529 Add keep_dep tag to :global_device_array build target to hint that it should be kept.
PiperOrigin-RevId: 510241400
2023-02-16 14:15:21 -08:00
Peter Hawkins
43b615c0a0 Move global_device_array into its own BUILD target.
PiperOrigin-RevId: 510229248
2023-02-16 13:30:40 -08:00
jax authors
b8d6efe22f Merge pull request #14273 from mattjj:shard-map
PiperOrigin-RevId: 506820113
2023-02-02 23:25:39 -08:00
Matthew Johnson
ff1e9b3973 shard_map (shmap) prototype and JEP
Co-authored-by: Sharad Vikram <sharadmv@google.com>
Co-authored-by: Sholto Douglas <sholto@google.com>
2023-02-02 23:01:30 -08:00
Peter Hawkins
c90a85403b Merge pull request #14248 from jakevdp:dead-code
PiperOrigin-RevId: 506405131
2023-02-01 21:25:46 +00:00
Sharad Vikram
c9a57e1b44 Delete jax.experimental.callback
PiperOrigin-RevId: 501760507
2023-01-12 22:58:31 -08:00
Qiao Zhang
4d1c4bc761 Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 491445515
2022-11-28 14:31:48 -08:00
jax authors
d1fbdbc1cf Rollback of "Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module."
PiperOrigin-RevId: 490499003
2022-11-23 07:48:05 -08:00
Qiao Zhang
78963b6020 Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 490387796
2022-11-22 18:53:29 -08:00
Jake VanderPlas
66262901f0 [sparse] improve testing framework 2022-11-16 09:58:06 -08:00
Yash Katariya
9e4114f0f1 Move array.py and sharding.py from experimental/ to _src/.
PiperOrigin-RevId: 477201711
2022-09-27 10:06:52 -07:00
Jake VanderPlas
265b39d23f Add pytype_srcs to main jax BUILD rule
PiperOrigin-RevId: 476989241
2022-09-26 14:18:13 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
79406757d0 Remove deprecated jax.experimental.optimizers
The new location is jax.example_libraries.optimizers
2022-08-09 08:50:59 -07:00
Peter Hawkins
b865111996 Refactor BUILD files to avoid individually naming Python dependencies.
Add a parametric py_deps() macro for adding Python package dependencies for Bazel rules.

Fix build failure with dangling matplotlib reference.

PiperOrigin-RevId: 465562141
2022-08-05 07:49:20 -07:00
Jake VanderPlas
91dbcbf525 Remove deprecated jax.experimental.stax
The new location is jax.example_libraries.stax
2022-08-02 16:50:06 -07:00
Yash Katariya
9a5af235da Delete sharded_jit
PiperOrigin-RevId: 464081692
2022-07-29 08:19:52 -07:00
George Necula
afa8f5acb4 Remove jax.experimental.loops. See CHANGELOG
PiperOrigin-RevId: 463297399
2022-07-26 03:39:47 -07:00
jax authors
b5e6145a42 Merge pull request #11359 from hawkinsp:bazel
PiperOrigin-RevId: 459234031
2022-07-06 06:13:20 -07:00
Peter Hawkins
1c75eee1ff Document how to run tests using Bazel.
* Add a new --configure_only option to build.py to allow build.py to generate a .bazelrc without necessarily building jaxlib.
* Add a bazel flag that make the dependency of //jax on //jaxlib optional. If //jaxlib isn't built by bazel, then tests will implicitly use a preinstalled jaxlib.
2022-07-06 08:30:35 -04:00
Peter Hawkins
1fc9afd03a Add support for running JAX tests under Bazel.
This is an alternative method for running the tests that some users may prefer: pytest is and will remain fully supported.

To use this, one creates a .bazelrc by running the existing `build.py` script, and then one can run the tests by running:
```
bazel test -c opt //tests/...
```

Issue #7323

PiperOrigin-RevId: 458551208
2022-07-01 15:07:22 -07:00
Peter Hawkins
b03466a390 Remove jax/BUILD file.
This Bazel build file is unused; we only use Bazel to build jaxlib, which does not include any files from jax/.

PiperOrigin-RevId: 374709682
2021-05-19 12:55:34 -07:00
Peter Hawkins
d481013f47 Add a CPU feature guard module to JAX.
To make sure that the CPU feature guard happens first, before any other code that may use instructions that do not exist, use a separate C extension module.

Fixes https://github.com/google/jax/issues/6671

PiperOrigin-RevId: 374683190
2021-05-19 10:58:35 -07:00
George Necula
342d62436f [host_callback] Add support for pjit of host_callback.
Currently, all XLA side-effect ops inside a sharded computation must have
explicit sharding. This includes the outfeed and infeed used by host_callback.

The implementation here uses AssignDevice sharding for both the outfeed and the
infeed. This means that before outfeed, the devices will do an all_gather and
the first device will make the outfeed. The host callback will receive a single
outfeed with the entire array, and is supposed to return the entire array. This
gets sent to the same device that issued to outfeed, which is responsible to
send the respective slices to the other participating devices.

PiperOrigin-RevId: 370711606
2021-04-27 10:30:13 -07:00
Jake VanderPlas
90d606fe25 Remove jax.experimental.doubledouble
PiperOrigin-RevId: 369740697
2021-04-21 14:52:14 -07:00
Adam Paszke
f285f0cc47 Fix pjit resource checks to verify the shapes against the local mesh.
PiperOrigin-RevId: 369263691
2021-04-19 11:28:47 -07:00
Jake VanderPlas
0d4bcde7ca Add experimental/sparse_ops & cusparse wrappers in jaxlib
PiperOrigin-RevId: 368663407
2021-04-15 10:11:10 -07:00
jax authors
babf249705 Merge pull request #5717 from google:dynamic-shapes2
PiperOrigin-RevId: 357851603
2021-02-16 18:45:13 -08:00
jax authors
14cc3f89a5 Merge pull request #5672 from skye:pjit
PiperOrigin-RevId: 356774892
2021-02-10 10:57:06 -08:00
jax authors
543dcb37e3 Merge pull request #5527 from google:remove-soft-pmap
PiperOrigin-RevId: 354328027
2021-01-28 09:29:57 -08:00
jax authors
f4b5ff9d46 Merge pull request #5488 from jakevdp:x64-contextmanager
PiperOrigin-RevId: 353722592
2021-01-25 13:51:12 -08:00
Peter Hawkins
9bdc2ecc66 Consolidate build macros into a single jax.bzl file.
PiperOrigin-RevId: 352871429
2021-01-20 14:06:22 -08:00
Peter Hawkins
929a684a39 Small cleanups to dependency structure.
PiperOrigin-RevId: 352853244
2021-01-20 12:43:28 -08:00
jax authors
2a699d0b04 Merge pull request #5038 from skye:sharded_jit_namespace
PiperOrigin-RevId: 346121390
2020-12-07 10:18:56 -08:00
Adam Paszke
d3136b44c8 Remove fake xmap resources, remove gmap
xmap can now handle real devices, so there's no point in maintaining the
simulated codepaths. Also, remove single-dimensional gmap as it will
have to be superseeded by a more xmap-friendly alternative.
2020-12-02 14:35:59 +00:00
Peter Hawkins
195e13c14b Remove jax.experimental.optix.
optix has become its own Python package (optax). You should use optax instead.

PiperOrigin-RevId: 343291598
2020-11-19 08:03:15 -08:00
Peter Hawkins
2b8d840cc3 [JAX] Remove uses of the deprecated jax.experimental.vectorize.
jax.numpy.vectorize should be used instead.

PiperOrigin-RevId: 341836454
2020-11-11 08:34:48 -08:00
Peter Hawkins
f58f1ee456 [JAX] Use PocketFFT for FFTs on CPU instead of Eigen.
PocketFFT is the same FFT library used by NumPy (although we are using the C++ variant rather than the C variant.)

For the benchmark in #2952 on my workstation:

Before:
```
907.3490574884647
max:     4.362646594533903e-08
mean:    6.237288307614869e-09
min:     0.0
numpy fft execution time [ms]:   37.088446617126465
jax fft execution time [ms]:     74.93342399597168
```

After:
```
907.3490574884647
max:     1.9057386696477137e-12
mean:    3.9326737908882566e-13
min:     0.0
numpy fft execution time [ms]:   37.756404876708984
jax fft execution time [ms]:     28.128278255462646
```

Fixes https://github.com/google/jax/issues/2952

PiperOrigin-RevId: 338743753
2020-10-23 14:20:32 -07:00
Peter Hawkins
6acb46516e Move most of the implementation of jax.scipy into jax._src.scipy. 2020-10-16 17:04:25 -04:00