127 Commits

Author SHA1 Message Date
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jean-Baptiste Lespiau
2929d79d0b Use value instead of ValueOrDie.
PiperOrigin-RevId: 473011038
2022-09-08 09:35:52 -07:00
jax authors
eb7040d6d6 Merge pull request #11629 from froystig:rm-control-example
PiperOrigin-RevId: 463462047
2022-07-26 17:12:51 -07:00
jax authors
32e77772b9 Rename xla::PjRtExecutable to xla::PjRtLoadedExecutable
PiperOrigin-RevId: 463460929
2022-07-26 17:06:45 -07:00
Roy Frostig
eadc3466c0 remove control example 2022-07-26 15:37:46 -07:00
Matthew Johnson
e350894371 remove resnet50 example 2022-07-19 09:40:39 -07:00
jax authors
b5e6145a42 Merge pull request #11359 from hawkinsp:bazel
PiperOrigin-RevId: 459234031
2022-07-06 06:13:20 -07:00
Peter Hawkins
1c75eee1ff Document how to run tests using Bazel.
* Add a new --configure_only option to build.py to allow build.py to generate a .bazelrc without necessarily building jaxlib.
* Add a bazel flag that make the dependency of //jax on //jaxlib optional. If //jaxlib isn't built by bazel, then tests will implicitly use a preinstalled jaxlib.
2022-07-06 08:30:35 -04:00
Jake VanderPlas
1a995a0c61 [x64] make examples/control_test compatible with strict dtype promotion 2022-06-16 16:20:54 -07:00
Galen Andrew
ef9036abf3 Migrate from deprecated tensorflow_privacy RDP accountant to differential_privacy.
PiperOrigin-RevId: 454724315
2022-06-13 16:26:51 -07:00
jvmncs
c128876b34
incorrect link in dpsgd example 2022-05-19 16:35:07 -04:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
jax authors
d4c6f95331 Rename sync versions of ToLiteral to ToLiteralSync to facilitate upcoming refactor.
PiperOrigin-RevId: 432285004
2022-03-03 14:29:51 -08:00
Jake VanderPlas
e376df29be disable implicit rank promotion in a number of remaining tests 2022-01-28 08:16:30 -08:00
Jake VanderPlas
3f1d21ad73 examples tests: avoid use of private jax utilities 2021-12-10 11:42:36 -08:00
Jake VanderPlas
df0969961b Testing: avoid hard-coding random seeds 2021-12-10 10:32:09 -08:00
Roy Frostig
623c201054 [JAX] move example libraries from jax.experimental into jax.example_libraries
The `jax.experimental.stax` and `jax.experimental.optimizers` modules are standalone examples libraries. By contrast, the remaining modules in `jax.experimental` are experimental features of the JAX core system. This change moves the two example libraries, and the README that describes them, to `jax.example_libraries` to reflect this distinction.

PiperOrigin-RevId: 404405186
2021-10-19 17:30:45 -07:00
Peter Hawkins
256e7220ff [JAX] Fix pylint errors.
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...

PiperOrigin-RevId: 400858477
2021-10-04 17:54:46 -07:00
Peter Hawkins
02d7d837e6 [JAX] Replace uses of deprecated jax.ops.index_update(x, idx, y) APIs with their up-to-date, more succinct equivalent x.at[idx].set(y).
The JAX operators:
jax.ops.index_update(x, jax.ops.index[idx], y)
jax.ops.index_add(x, jax.ops.index[idx], y)
...

have long been deprecated in lieu of their more succinct counterparts:
x.at[idx].set(y)
x.at[idx].add(y)
...

This change updates users of the deprecated APIs to use the current APIs, in preparation for removing the deprecated forms from JAX.

The main subtlety is that if `x` is not a JAX array, we must cast it to one using `jnp.asarray(x)` before using the new form, since `.at[...]` is only defined on JAX arrays.

PiperOrigin-RevId: 400209692
2021-10-01 08:45:31 -07:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Peter Hawkins
2c2f4033cc Move contents of jax.lib to jax._src.lib.
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.

PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
Roy Frostig
9f8e3df320 fix advi example: pass tuples rather than lists to jitted function 2021-08-10 10:31:26 -07:00
slowy07
9eadb07bdc fix: miss typo codespell and documentation 2021-07-24 15:25:13 +07:00
Qiao Zhang
cf64f840e1 Remove redundant dep in jax cpp example target. 2021-06-22 14:49:19 -07:00
George Necula
6a48c60a72 Rename master to main in embedded links.
Tried to avoid the change on external links to repos that
have not yet renamed master.
2021-06-18 10:00:01 +03:00
Peter Hawkins
26e9ebcdae Move jax.api to jax._src.api.
PiperOrigin-RevId: 368233837
2021-04-13 09:43:24 -07:00
jax authors
94e6314ba4 Merge pull request #6025 from bastings:patch-2
PiperOrigin-RevId: 363318137
2021-03-16 18:43:12 -07:00
Jasmijn Bastings
6470094f04
Fix import of datasets in differentially private SGD example 2021-03-11 11:32:43 +01:00
Skye Wanderman-Milne
cd619978bb Adjust precision in examples/kernel_lsq.py and corresponding test.
This is important on TPU.
2021-03-11 03:03:43 +00:00
jax authors
2e19b1d150 Merge pull request #5383 from zhangqiaorjc:jax_cpp
PiperOrigin-RevId: 353780826
2021-01-25 19:21:58 -08:00
Qiao Zhang
7e155244b6 Add example code to save JAX program and run using C++ runtime. 2021-01-25 19:12:26 -08:00
Peter Hawkins
160dfd343a Revert import path changes to examples/ and benchmarks/
PiperOrigin-RevId: 352911869
2021-01-20 17:35:55 -08:00
Peter Hawkins
929a684a39 Small cleanups to dependency structure.
PiperOrigin-RevId: 352853244
2021-01-20 12:43:28 -08:00
Nicholas Vadivelu
238b4a1236
Update examples/differentially_private_sgd.py
Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-10-05 00:54:18 -04:00
Nicholas Vadivelu
2ced1635b2 fix dp sgd example 2020-10-02 23:41:23 -07:00
Peter Hawkins
1b3dd65daf
Avoid lexically capturing the train_images value in MNIST VAE example. (#3947)
* Avoid lexically capturing the train_images value in MNIST VAE example.

This has the effect of baking in the training dataset as a constant, something that LLVM does not like that much.

* Add device_put to images.
2020-08-03 16:50:20 -04:00
Matthew Johnson
d10cf0e38f
fix prng key reuse in differential privacy example (#3646)
fix prng key reuse in differential privacy example
2020-07-02 14:29:17 -07:00
Jake VanderPlas
afce718eb1 Add ability to specify individual test targets 2020-06-29 11:08:57 -07:00
Jake Vanderplas
6aa8f2461c
Fix remaining flakes and use exclude within setup.cfg (#3371) 2020-06-08 22:58:03 -07:00
Matthew Johnson
c42a7f7890
remove some trailing whitespace (#3287) 2020-06-02 17:37:20 -07:00
Peter Hawkins
bc5a0b336b
Remove some uses of jax.partial. (#3131) 2020-05-18 10:19:03 -04:00
Peter Hawkins
d59ecddfe8
Replace np -> jnp, onp -> np in examples/ (#2971)
For context, see #2370
2020-05-05 15:45:07 -04:00
Jin Dong
60d856ab9f remove from __future__ code 2020-04-09 18:16:47 -04:00
Xiayun Sun
651316f4c7
Fix issue 1465: fix jit in example (#1473)
* fix jit in example

* Avoid using static_argnums on a keyword argument; use a positional argument and a wrapper function for now.

Co-authored-by: Peter Hawkins <phawkins@google.com>
2020-04-01 11:57:57 -04:00
Matthew Johnson
b015e57169 try re-enabling control tests that trigger #2507 2020-03-30 20:12:33 -07:00
Matthew Johnson
7e480fa923 add custom_jvp / vjp, delete custom_transforms 2020-03-21 22:08:03 -07:00
Srinivas Vasudevan
c7f211d433
Update JAX to use XLA hyperbolic functions. (#2415) 2020-03-19 10:29:37 -04:00
bddppq
ac6a313cfc
Fix ONNX mnist example (#2374)
* Fix ONNX mnist example

* use np to compute the shape; rename jax.numpy as jnp
2020-03-09 16:04:59 -04:00
Roy Frostig
6da6df0ae1 fix comment typo in MPC/LQR example 2020-02-26 08:00:27 -08:00
Roy Frostig
9cb8171fa7 remove unused imports in MPC/LQR example 2020-02-20 11:07:06 -08:00