114 Commits

Author SHA1 Message Date
Jake VanderPlas
e376df29be disable implicit rank promotion in a number of remaining tests 2022-01-28 08:16:30 -08:00
Jake VanderPlas
3f1d21ad73 examples tests: avoid use of private jax utilities 2021-12-10 11:42:36 -08:00
Jake VanderPlas
df0969961b Testing: avoid hard-coding random seeds 2021-12-10 10:32:09 -08:00
Roy Frostig
623c201054 [JAX] move example libraries from jax.experimental into jax.example_libraries
The `jax.experimental.stax` and `jax.experimental.optimizers` modules are standalone examples libraries. By contrast, the remaining modules in `jax.experimental` are experimental features of the JAX core system. This change moves the two example libraries, and the README that describes them, to `jax.example_libraries` to reflect this distinction.

PiperOrigin-RevId: 404405186
2021-10-19 17:30:45 -07:00
Peter Hawkins
256e7220ff [JAX] Fix pylint errors.
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...

PiperOrigin-RevId: 400858477
2021-10-04 17:54:46 -07:00
Peter Hawkins
02d7d837e6 [JAX] Replace uses of deprecated jax.ops.index_update(x, idx, y) APIs with their up-to-date, more succinct equivalent x.at[idx].set(y).
The JAX operators:
jax.ops.index_update(x, jax.ops.index[idx], y)
jax.ops.index_add(x, jax.ops.index[idx], y)
...

have long been deprecated in lieu of their more succinct counterparts:
x.at[idx].set(y)
x.at[idx].add(y)
...

This change updates users of the deprecated APIs to use the current APIs, in preparation for removing the deprecated forms from JAX.

The main subtlety is that if `x` is not a JAX array, we must cast it to one using `jnp.asarray(x)` before using the new form, since `.at[...]` is only defined on JAX arrays.

PiperOrigin-RevId: 400209692
2021-10-01 08:45:31 -07:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Peter Hawkins
2c2f4033cc Move contents of jax.lib to jax._src.lib.
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.

PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
Roy Frostig
9f8e3df320 fix advi example: pass tuples rather than lists to jitted function 2021-08-10 10:31:26 -07:00
slowy07
9eadb07bdc fix: miss typo codespell and documentation 2021-07-24 15:25:13 +07:00
Qiao Zhang
cf64f840e1 Remove redundant dep in jax cpp example target. 2021-06-22 14:49:19 -07:00
George Necula
6a48c60a72 Rename master to main in embedded links.
Tried to avoid the change on external links to repos that
have not yet renamed master.
2021-06-18 10:00:01 +03:00
Peter Hawkins
26e9ebcdae Move jax.api to jax._src.api.
PiperOrigin-RevId: 368233837
2021-04-13 09:43:24 -07:00
jax authors
94e6314ba4 Merge pull request #6025 from bastings:patch-2
PiperOrigin-RevId: 363318137
2021-03-16 18:43:12 -07:00
Jasmijn Bastings
6470094f04
Fix import of datasets in differentially private SGD example 2021-03-11 11:32:43 +01:00
Skye Wanderman-Milne
cd619978bb Adjust precision in examples/kernel_lsq.py and corresponding test.
This is important on TPU.
2021-03-11 03:03:43 +00:00
jax authors
2e19b1d150 Merge pull request #5383 from zhangqiaorjc:jax_cpp
PiperOrigin-RevId: 353780826
2021-01-25 19:21:58 -08:00
Qiao Zhang
7e155244b6 Add example code to save JAX program and run using C++ runtime. 2021-01-25 19:12:26 -08:00
Peter Hawkins
160dfd343a Revert import path changes to examples/ and benchmarks/
PiperOrigin-RevId: 352911869
2021-01-20 17:35:55 -08:00
Peter Hawkins
929a684a39 Small cleanups to dependency structure.
PiperOrigin-RevId: 352853244
2021-01-20 12:43:28 -08:00
Nicholas Vadivelu
238b4a1236
Update examples/differentially_private_sgd.py
Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-10-05 00:54:18 -04:00
Nicholas Vadivelu
2ced1635b2 fix dp sgd example 2020-10-02 23:41:23 -07:00
Peter Hawkins
1b3dd65daf
Avoid lexically capturing the train_images value in MNIST VAE example. (#3947)
* Avoid lexically capturing the train_images value in MNIST VAE example.

This has the effect of baking in the training dataset as a constant, something that LLVM does not like that much.

* Add device_put to images.
2020-08-03 16:50:20 -04:00
Matthew Johnson
d10cf0e38f
fix prng key reuse in differential privacy example (#3646)
fix prng key reuse in differential privacy example
2020-07-02 14:29:17 -07:00
Jake VanderPlas
afce718eb1 Add ability to specify individual test targets 2020-06-29 11:08:57 -07:00
Jake Vanderplas
6aa8f2461c
Fix remaining flakes and use exclude within setup.cfg (#3371) 2020-06-08 22:58:03 -07:00
Matthew Johnson
c42a7f7890
remove some trailing whitespace (#3287) 2020-06-02 17:37:20 -07:00
Peter Hawkins
bc5a0b336b
Remove some uses of jax.partial. (#3131) 2020-05-18 10:19:03 -04:00
Peter Hawkins
d59ecddfe8
Replace np -> jnp, onp -> np in examples/ (#2971)
For context, see #2370
2020-05-05 15:45:07 -04:00
Jin Dong
60d856ab9f remove from __future__ code 2020-04-09 18:16:47 -04:00
Xiayun Sun
651316f4c7
Fix issue 1465: fix jit in example (#1473)
* fix jit in example

* Avoid using static_argnums on a keyword argument; use a positional argument and a wrapper function for now.

Co-authored-by: Peter Hawkins <phawkins@google.com>
2020-04-01 11:57:57 -04:00
Matthew Johnson
b015e57169 try re-enabling control tests that trigger #2507 2020-03-30 20:12:33 -07:00
Matthew Johnson
7e480fa923 add custom_jvp / vjp, delete custom_transforms 2020-03-21 22:08:03 -07:00
Srinivas Vasudevan
c7f211d433
Update JAX to use XLA hyperbolic functions. (#2415) 2020-03-19 10:29:37 -04:00
bddppq
ac6a313cfc
Fix ONNX mnist example (#2374)
* Fix ONNX mnist example

* use np to compute the shape; rename jax.numpy as jnp
2020-03-09 16:04:59 -04:00
Roy Frostig
6da6df0ae1 fix comment typo in MPC/LQR example 2020-02-26 08:00:27 -08:00
Roy Frostig
9cb8171fa7 remove unused imports in MPC/LQR example 2020-02-20 11:07:06 -08:00
Roy Frostig
f02578ea9e set tolerances and toggle dtype checks in MPC/LQR example tests 2020-02-19 19:55:47 -08:00
Roy Frostig
f1af9893a4 move MPC/LQR example to examples directory 2020-02-19 13:47:18 -08:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
Matthew Johnson
8bca2c90e7 fix urllib import for py3 2020-01-09 20:25:42 -08:00
Peter Hawkins
dcc882cf6b
Drop Python 2 support from JAX. (#1962)
Remove six dependency.
2020-01-08 13:17:55 -05:00
Roy Frostig
1648435268 enable kernel regression example test 2019-11-20 14:25:09 -08:00
Jonas Rauber
f8c5d98653 fixed cross-entropy losses in mnist examples (fixes #1023) 2019-10-27 09:43:46 +01:00
Karthik Kumara
1090e89a86 - sign missing from loss function definition 2019-10-17 16:44:27 -07:00
Nicolas Papernot
a411384d48
typo 2019-08-09 17:18:31 -07:00
Fabian Pedregosa
b3df11d64a
Typo in docstring
jax.experimentaloptimizers -> jax.experimental.optimizers
2019-08-06 09:07:48 -04:00
Jasper Snoek
06d41fbad1 Added a note about squared distances 2019-05-30 12:13:10 -04:00
Jasper Snoek
3d1e419e0d Updating GP regression example 2019-05-29 16:36:45 -04:00
Jasper Snoek
58c0221505 Updating GP regression example 2019-05-29 16:35:39 -04:00