137 Commits

Author SHA1 Message Date
Peter Hawkins
70b7d50181 Switch jaxlib to use nanobind instead of pybind11.
nanobind has a number of advantages (https://nanobind.readthedocs.io/en/latest/why.html), notably speed of compilation and dispatch, but the main reason to do this for these bindings is because nanobind can target the Python Stable ABI starting with Python 3.12. This means that we will not need to ship per-Python version CUDA plugins starting with Python 3.12.

PiperOrigin-RevId: 559898790
2023-08-24 16:07:56 -07:00
David Cottrell
40d0d40b6c Fix for log-normal prior in example. 2023-08-09 13:38:37 +01:00
Peter Hawkins
76cda0ae07 Update flags to use the ABSL typed flag API.
Change flags to use the newer definition style where the flag is read via a typed FlagHolder object returned by the DEFINE_... function. The advantage of doing this is that `flag.value` has a type known to the type checker, rather than reading it as an attr out of a gigantic config dictionary.

For jax.config flags, define a typed FlagHolder object that is returned when defining a flag, matching the ABSL API.

Move a number of flags into the file that consumes them. There's no reason we're defining every flag in `config.py`.

This PR does not change the similar "state" objects in `jax.config`. Changing those is for a future PR.

PiperOrigin-RevId: 551604974
2023-07-27 12:15:58 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Peter Hawkins
3bb7386149 [JAX] Improve handling of metadata in compilation cache.
Metadata, in particular code location information is present in the HLO generated by JAX. The compilation cache uses the serialized HLO as a cache key, which begs the question: should code location information be part of that key? Simply changing the line number on which a function appears shouldn't necessarily cause a cache miss.

There are pros and cons: the main advantage of excluding metadata is that we will get more cache hits, and the main disadvantage is that debug information and profiling data in the HLO might become confusing, since it may refer to a different program entirely, or to a version of a program that does not correspond to the current state of the source tree. We argue that saving compilation time is the more important concern.

This change adds a tiny MLIR pass that strips Locations from a StableHLO module, and applies it in the compilation cache if metadata stripping is enabled.

PiperOrigin-RevId: 525534901
2023-04-19 13:27:04 -07:00
Peter Hawkins
9f03bd51ca Make differential privacy example work again.
* Use dp_accounting import instead of the TensorFlow privacy package.
* Correct example dataset import.
* Fix use of list argument to jnp.linalg.norm()

Fixes https://github.com/google/jax/issues/15500

PiperOrigin-RevId: 523168395
2023-04-10 11:31:06 -07:00
Peter Hawkins
ab45383038 Fix build breakage from OpenXLA switch.
PiperOrigin-RevId: 516325478
2023-03-13 14:37:35 -07:00
jax authors
42ef649e65 Merge pull request #14475 from hawkinsp:openxla
PiperOrigin-RevId: 516316330
2023-03-13 14:04:41 -07:00
Peter Hawkins
172a831219 Switch JAX to use the OpenXLA repository. 2023-03-13 18:38:26 +00:00
Peter Hawkins
88379603e0 [PJRT] Delete the old :cpu_device target that uses StreamExecutor.
The TFRT CPU client is better in every way and the SE CPU client is unmaintained and has not been used by JAX in many months.

PiperOrigin-RevId: 489246256
2022-11-17 10:29:03 -08:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jean-Baptiste Lespiau
2929d79d0b Use value instead of ValueOrDie.
PiperOrigin-RevId: 473011038
2022-09-08 09:35:52 -07:00
jax authors
eb7040d6d6 Merge pull request #11629 from froystig:rm-control-example
PiperOrigin-RevId: 463462047
2022-07-26 17:12:51 -07:00
jax authors
32e77772b9 Rename xla::PjRtExecutable to xla::PjRtLoadedExecutable
PiperOrigin-RevId: 463460929
2022-07-26 17:06:45 -07:00
Roy Frostig
eadc3466c0 remove control example 2022-07-26 15:37:46 -07:00
Matthew Johnson
e350894371 remove resnet50 example 2022-07-19 09:40:39 -07:00
jax authors
b5e6145a42 Merge pull request #11359 from hawkinsp:bazel
PiperOrigin-RevId: 459234031
2022-07-06 06:13:20 -07:00
Peter Hawkins
1c75eee1ff Document how to run tests using Bazel.
* Add a new --configure_only option to build.py to allow build.py to generate a .bazelrc without necessarily building jaxlib.
* Add a bazel flag that make the dependency of //jax on //jaxlib optional. If //jaxlib isn't built by bazel, then tests will implicitly use a preinstalled jaxlib.
2022-07-06 08:30:35 -04:00
Jake VanderPlas
1a995a0c61 [x64] make examples/control_test compatible with strict dtype promotion 2022-06-16 16:20:54 -07:00
Galen Andrew
ef9036abf3 Migrate from deprecated tensorflow_privacy RDP accountant to differential_privacy.
PiperOrigin-RevId: 454724315
2022-06-13 16:26:51 -07:00
jvmncs
c128876b34
incorrect link in dpsgd example 2022-05-19 16:35:07 -04:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
jax authors
d4c6f95331 Rename sync versions of ToLiteral to ToLiteralSync to facilitate upcoming refactor.
PiperOrigin-RevId: 432285004
2022-03-03 14:29:51 -08:00
Jake VanderPlas
e376df29be disable implicit rank promotion in a number of remaining tests 2022-01-28 08:16:30 -08:00
Jake VanderPlas
3f1d21ad73 examples tests: avoid use of private jax utilities 2021-12-10 11:42:36 -08:00
Jake VanderPlas
df0969961b Testing: avoid hard-coding random seeds 2021-12-10 10:32:09 -08:00
Roy Frostig
623c201054 [JAX] move example libraries from jax.experimental into jax.example_libraries
The `jax.experimental.stax` and `jax.experimental.optimizers` modules are standalone examples libraries. By contrast, the remaining modules in `jax.experimental` are experimental features of the JAX core system. This change moves the two example libraries, and the README that describes them, to `jax.example_libraries` to reflect this distinction.

PiperOrigin-RevId: 404405186
2021-10-19 17:30:45 -07:00
Peter Hawkins
256e7220ff [JAX] Fix pylint errors.
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...

PiperOrigin-RevId: 400858477
2021-10-04 17:54:46 -07:00
Peter Hawkins
02d7d837e6 [JAX] Replace uses of deprecated jax.ops.index_update(x, idx, y) APIs with their up-to-date, more succinct equivalent x.at[idx].set(y).
The JAX operators:
jax.ops.index_update(x, jax.ops.index[idx], y)
jax.ops.index_add(x, jax.ops.index[idx], y)
...

have long been deprecated in lieu of their more succinct counterparts:
x.at[idx].set(y)
x.at[idx].add(y)
...

This change updates users of the deprecated APIs to use the current APIs, in preparation for removing the deprecated forms from JAX.

The main subtlety is that if `x` is not a JAX array, we must cast it to one using `jnp.asarray(x)` before using the new form, since `.at[...]` is only defined on JAX arrays.

PiperOrigin-RevId: 400209692
2021-10-01 08:45:31 -07:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Peter Hawkins
2c2f4033cc Move contents of jax.lib to jax._src.lib.
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.

PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
Roy Frostig
9f8e3df320 fix advi example: pass tuples rather than lists to jitted function 2021-08-10 10:31:26 -07:00
slowy07
9eadb07bdc fix: miss typo codespell and documentation 2021-07-24 15:25:13 +07:00
Qiao Zhang
cf64f840e1 Remove redundant dep in jax cpp example target. 2021-06-22 14:49:19 -07:00
George Necula
6a48c60a72 Rename master to main in embedded links.
Tried to avoid the change on external links to repos that
have not yet renamed master.
2021-06-18 10:00:01 +03:00
Peter Hawkins
26e9ebcdae Move jax.api to jax._src.api.
PiperOrigin-RevId: 368233837
2021-04-13 09:43:24 -07:00
jax authors
94e6314ba4 Merge pull request #6025 from bastings:patch-2
PiperOrigin-RevId: 363318137
2021-03-16 18:43:12 -07:00
Jasmijn Bastings
6470094f04
Fix import of datasets in differentially private SGD example 2021-03-11 11:32:43 +01:00
Skye Wanderman-Milne
cd619978bb Adjust precision in examples/kernel_lsq.py and corresponding test.
This is important on TPU.
2021-03-11 03:03:43 +00:00
jax authors
2e19b1d150 Merge pull request #5383 from zhangqiaorjc:jax_cpp
PiperOrigin-RevId: 353780826
2021-01-25 19:21:58 -08:00
Qiao Zhang
7e155244b6 Add example code to save JAX program and run using C++ runtime. 2021-01-25 19:12:26 -08:00
Peter Hawkins
160dfd343a Revert import path changes to examples/ and benchmarks/
PiperOrigin-RevId: 352911869
2021-01-20 17:35:55 -08:00
Peter Hawkins
929a684a39 Small cleanups to dependency structure.
PiperOrigin-RevId: 352853244
2021-01-20 12:43:28 -08:00
Nicholas Vadivelu
238b4a1236
Update examples/differentially_private_sgd.py
Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-10-05 00:54:18 -04:00
Nicholas Vadivelu
2ced1635b2 fix dp sgd example 2020-10-02 23:41:23 -07:00
Peter Hawkins
1b3dd65daf
Avoid lexically capturing the train_images value in MNIST VAE example. (#3947)
* Avoid lexically capturing the train_images value in MNIST VAE example.

This has the effect of baking in the training dataset as a constant, something that LLVM does not like that much.

* Add device_put to images.
2020-08-03 16:50:20 -04:00
Matthew Johnson
d10cf0e38f
fix prng key reuse in differential privacy example (#3646)
fix prng key reuse in differential privacy example
2020-07-02 14:29:17 -07:00
Jake VanderPlas
afce718eb1 Add ability to specify individual test targets 2020-06-29 11:08:57 -07:00
Jake Vanderplas
6aa8f2461c
Fix remaining flakes and use exclude within setup.cfg (#3371) 2020-06-08 22:58:03 -07:00
Matthew Johnson
c42a7f7890
remove some trailing whitespace (#3287) 2020-06-02 17:37:20 -07:00