50 Commits

Author SHA1 Message Date
Jake VanderPlas
2f878a7168 Tests: set jax_legacy_prng_key='error' 2023-08-28 10:56:09 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Matthew Johnson
6ba0ef6505 relax tanh test tols for upcoming xla change 2022-12-20 21:06:09 -08:00
Jake VanderPlas
f09fd8a4e9 [x64] minor test-only updates for better type safety 2022-11-30 15:18:40 -08:00
Jake VanderPlas
243b931e28 BUG: fix jet rule for dynamic_slice 2022-11-02 21:37:41 -07:00
Jake VanderPlas
db7eea1f60 add jet rule for dynamic_update_slice 2022-11-02 14:47:11 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Peter Hawkins
57b5acf1b6 Roll forward: Upgrade logistic into a primitive.
Unlike the previous attempt, we don't try to use mhlo.logistic as the lowering of the new primitive yet. Instead, we lower to the old implementation of `expit`. This means that this change should be a no-op numerically and we can work on changing its implementation in a subsequent change.

PiperOrigin-RevId: 472705623
2022-09-07 06:06:56 -07:00
jax authors
9c16c83234 Rollback of upgrade logistic (sigmoid) function into a lax primitive.
PiperOrigin-RevId: 471105650
2022-08-30 15:30:43 -07:00
Peter Hawkins
f68f1c0cd0 Upgrade logistic (sigmoid) function into a lax primitive.
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code.

PiperOrigin-RevId: 470300985
2022-08-26 11:58:28 -07:00
jax authors
3e3542b0d6 Upgrade logistic (sigmoid) function into a lax primitive.
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code.

PiperOrigin-RevId: 469841487
2022-08-24 15:39:37 -07:00
Peter Hawkins
6276194e1c Upgrade logistic (sigmoid) function into a lax primitive.
This allows us to lower it to `mhlo.logistic`, which allows XLA to generate more efficient code.

PiperOrigin-RevId: 469789339
2022-08-24 12:04:01 -07:00
Jake VanderPlas
a10f0377db Avoid top-level aliases of jax.tree_util.* 2022-07-07 11:41:02 -07:00
Jake VanderPlas
893179fcfc [x64] make jet_test compatible with strict dtype promotion 2022-06-21 09:28:24 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Nicholas Krämer
c07a0f1139 Add test and jet-primitive for dynamic_slice 2022-02-08 13:28:41 +01:00
Jake VanderPlas
df0969961b Testing: avoid hard-coding random seeds 2021-12-10 10:32:09 -08:00
Roy Frostig
623c201054 [JAX] move example libraries from jax.experimental into jax.example_libraries
The `jax.experimental.stax` and `jax.experimental.optimizers` modules are standalone examples libraries. By contrast, the remaining modules in `jax.experimental` are experimental features of the JAX core system. This change moves the two example libraries, and the README that describes them, to `jax.example_libraries` to reflect this distinction.

PiperOrigin-RevId: 404405186
2021-10-19 17:30:45 -07:00
Peter Hawkins
256e7220ff [JAX] Fix pylint errors.
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...

PiperOrigin-RevId: 400858477
2021-10-04 17:54:46 -07:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Roman Novak
b65f39ca7a Default to jnp.float_ type in nn.initializers. 2021-08-17 20:41:09 -07:00
Skye Wanderman-Milne
5fe8712249 Adjust JetTest.test_scatter_add tolerance for TPU 2021-03-31 17:25:27 +00:00
Matthew Johnson
aa2472db0c add scatter_add jet rule, fixes #5365
could use a better test though...
2021-03-30 14:04:40 -07:00
Skye Wanderman-Milne
3c3dc0f6fc Adjust tolerances in jet_test.py for TPU.
Most of these tests are disabled on TPU and can probably be enabled, but I just fixed the currently-enabled tests for now.
2021-03-03 23:49:20 +00:00
Matthew Johnson
d7b5e3b5d4 add add_any to jet rules table
fixes #5217
2020-12-17 12:10:12 -08:00
Jake VanderPlas
f74235cdae X32 tests: fail on dtype warnings 2020-12-08 13:03:30 -08:00
Lena Martens
ecad419cf3 Support grad with integer arguments.
- Add float0 and set-up at_least_vspace to return float0
values for int/bool primals
- Use Zero to wrap float0 tangents so they're correctly ignored in jvp
rules
- Add float0 handlers to XLA to support jit
- Fix convert_element_type and tie_in jvp rules
2020-09-28 19:07:04 +01:00
Cambridge Yang
fe9f264b55
cumulative jet rules (#4000) 2020-08-11 07:09:54 -07:00
Jake VanderPlas
afce718eb1 Add ability to specify individual test targets 2020-06-29 11:08:57 -07:00
Jacob Kelly
575216e094
add jet primitives, refactor tests (#3468)
Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>
2020-06-16 19:48:25 -07:00
Matthew Johnson
270e921489
Merge pull request #3456 from jacobjinkelly/int_pow
Jet rule for `integer_pow`
2020-06-15 22:17:25 -07:00
Jacob Kelly
f463598f19 add int pow rule 2020-06-15 17:23:57 -04:00
Jacob Kelly
3cf6b1de54 add erf inv rule
erf_inv rule not working

works up to order 2

erf inv rule

use np for now

actually use np for now
2020-06-15 15:40:31 -04:00
Peter Hawkins
fffdb2daa8
Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_… (#3280)
* Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_util APIs.
Default to check_dtypes=True.

Remove explicit usages of check_dtypes=True from tests. This mostly just removes visual noise from tests. Testing for exact type equality is the sensible default, although there are cases where opting out makes sense.

No functional changes intended.

* Fix a number of lax reference implementations to preserve types.
2020-06-01 17:19:23 -04:00
Jake Vanderplas
bc30597780
Cleanup: remove unused imports in tests (#3276) 2020-06-01 11:49:35 -07:00
Jacob Kelly
83a339e161
add erf and erfc rules (#3051)
refactor def comp
2020-05-19 15:22:25 -07:00
Peter Hawkins
b543652332
Replace np -> jnp, onp -> np in tests. (#2969) 2020-05-05 14:59:16 -04:00
Jacob Kelly
a821e67d60
instantiate zeros (#2924)
fix dtype

remove TODO
2020-05-01 17:10:20 -07:00
Jacob Kelly
1f7ebabfc8
add jets for sines fns (#2892)
refactor

remove duplicate
2020-04-29 19:18:21 -07:00
Jacob Kelly
fc4203c38a
implement jet rules by lowering to other primitives (#2816)
merge jet_test

add jet rules

use lax.square
2020-04-23 22:07:35 -07:00
Jacob Kelly
59bdb1fb3d
add tanh rule (#2653)
change expit taylor rule

add manual expit check, check stability of expit and tanh
2020-04-22 17:49:10 -07:00
Peter Hawkins
1298e9e8c4
Fix some test failures. (#2713) 2020-04-14 18:23:19 -04:00
Jacob Kelly
8503656ea8 add finite test, add sep lims for binary_check 2020-04-09 19:57:01 -04:00
Jacob Kelly
1fa0e8a67d jet of pow using comp with exp, mul, log
Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>
Co-authored-by: David Duvenaud <duvenaud@cs.toronto.edu>
2020-04-08 21:51:44 -04:00
Jacob Kelly
4d7b63c5ec add expm1 and log1p 2020-04-07 17:55:07 -04:00
Matthew Johnson
84dc6cc1c4 post process call of jet!
Also included David's jet rule for lax.select.

Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu>
Co-authored-by: Jacob Kelly <jacob.jin.kelly@gmail.com>
Co-authored-by: David Duvenaud <duvenaud@cs.toronto.edu>
2020-04-02 07:56:26 -07:00
David Duvenaud
ead8011837 Added lots of trivial jet rules.
Co-Authored-By: jessebett <jessebett@gmail.com>
Co-Authored-By: Jacob Kelly <jacob.kelly@mail.utoronto.ca>
2020-03-29 16:28:17 -04:00
Matthew Johnson
a00e3986d4 remove scipy dep, fix dtype issue 2020-03-15 12:00:44 -07:00
Matthew Johnson
a7b3be71e8 move jet into jax.experimental 2020-03-15 11:10:56 -07:00
Matthew Johnson
668a1703bc add jet tests, remove top-level files 2020-03-14 21:22:10 -07:00