1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-26 07:56:06 +00:00

32 Commits

Author SHA1 Message Date
Sergei Lebedev
0ff234049b Removed trivial docstrings from JAX tests
These docstrings do not make the tests any more clear and typically just duplicate the test module name.

PiperOrigin-RevId: 737611977
2025-03-17 07:49:37 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Jake VanderPlas
cddee4654c tests: access tree utilities via jax.tree.* 2024-02-26 14:17:18 -08:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
df0969961b Testing: avoid hard-coding random seeds 2021-12-10 10:32:09 -08:00
Roy Frostig
623c201054 [JAX] move example libraries from jax.experimental into jax.example_libraries
The `jax.experimental.stax` and `jax.experimental.optimizers` modules are standalone examples libraries. By contrast, the remaining modules in `jax.experimental` are experimental features of the JAX core system. This change moves the two example libraries, and the README that describes them, to `jax.example_libraries` to reflect this distinction.

PiperOrigin-RevId: 404405186
2021-10-19 17:30:45 -07:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Mikhail Grankin
0b1b812e3b bugfix for sm3 optimizer
Ensure atleast 1d input. Optimizer should work now with scalar data.
fixes 
2021-08-04 20:56:11 +03:00
Jake VanderPlas
f74235cdae X32 tests: fail on dtype warnings 2020-12-08 13:03:30 -08:00
Jake VanderPlas
afce718eb1 Add ability to specify individual test targets 2020-06-29 11:08:57 -07:00
Jake Vanderplas
9ee4ef1107
Cleanup: de-lint tests directory & add flake8 to travis ()
* Cleanup: fix lint errors in tests/*.py

* Add flake8 step to travis

* add setup.cfg
2020-06-02 19:25:47 -07:00
Peter Hawkins
fffdb2daa8
Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_… ()
* Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_util APIs.
Default to check_dtypes=True.

Remove explicit usages of check_dtypes=True from tests. This mostly just removes visual noise from tests. Testing for exact type equality is the sensible default, although there are cases where opting out makes sense.

No functional changes intended.

* Fix a number of lax reference implementations to preserve types.
2020-06-01 17:19:23 -04:00
Jake Vanderplas
bc30597780
Cleanup: remove unused imports in tests () 2020-06-01 11:49:35 -07:00
Peter Hawkins
b543652332
Replace np -> jnp, onp -> np in tests. () 2020-05-05 14:59:16 -04:00
Jacob Kelly
61fc2bf2c1 add adamax test 2020-04-19 16:02:17 -04:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. ()
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
Peter Hawkins
42dd736afd
Change scalar promotion rules to prefer array types over scalar types. ()
* Change scalar promotion rules to prefer array types over scalar types.

Currently JAX does not treat Python scalars specially during type promotion. This means that, for example:
`1. + np.array([...], np.float32)`
ends up as an array of type np.float64. The `1.` is promoted to a default type (here np.float64), and the type promotion of a np.float64 and an np.float32 is an np.float64. This is unlike classic NumPy, which treats scalars specially during type promotion, in particular, preferring the type of an array over the type of a scalar.

This change adds a notion of weak_type to JAX avals. During type promotion, we prefer non-weak types, i.e., the type of the array in the example above, ignoring the type of the scalar.

In contexts where a Python scalar is to be promoted to a NumPy value, a default type is used (e.g., `np.float_`). This change also makes it possible to use 32-bit default types that differ from NumPy's default types. The JAX test suite passes with 32-bit default types. However, we do not yet enable this change or expose it in the API.
2019-11-18 14:51:10 -05:00
Peter Hawkins
bbf8129aa6
Change test tolerance logic not to choose tolerance values based on f… ()
* Change test tolerance logic not to choose tolerance values based on flags (in particular, --jax_enable_x64).

We would like to move away from having global flags to enable 64-bit mode. We therefore need other methods to select test tolerances. Instead, use a per-type default tolerance, and allow tests to pass per-type dictionaries of tolerances as atol and rtol values. Fix up a number of tolerances to make tests pass.

* Fix test tolerances.

* Fix dtype canonicalization for test tolerances.

* Relax core test_vjp tolerance.
2019-11-16 13:51:42 -05:00
Trevor Cai
ae52f865f5 Add polynomial_decay schedule to optimizers 2019-08-30 18:33:58 +01:00
Peter Hawkins
d2f2da29f8 Enable some tests that now seem to pass. 2019-08-14 09:05:55 -04:00
George Dahl
444aced791 Fix pack_optimizer_state to correctly use tuples everywhere in the packed state and add a unit test to check round trip unpack/pack. 2019-08-02 16:42:17 -07:00
Matthew Johnson
09b229969f patch utility functions in batching.py
fixes 
2019-05-22 16:00:43 -07:00
Rohan Anil
a08f7ad5fa Adagrad optimizer 2019-05-13 20:36:45 -07:00
Matthew Johnson
9788a3584a bump version for pypi, leave DeviceTuples off 2019-05-07 15:50:22 -07:00
Matthew Johnson
9fc47e51f5 use DeviceTuples in optimizer states again 2019-05-06 22:43:46 -07:00
Matthew Johnson
e751189366 add optimizer utilities (fixes and ) 2019-05-06 16:10:10 -07:00
Matthew Johnson
bf6c15b59a update pmap to flatten correctly (was a perf bug)
also temporarily avoid DeviceTuples in optimizer states
2019-05-06 12:09:54 -07:00
Matthew Johnson
8b3baf25c0 add sm3 (PAIR w/ @lukaszkaiser and rohananil@) 2019-05-03 14:56:52 -07:00
Matthew Johnson
642d2dc802 revies optimizers api, fix misc bugs
* add more optimizers numerical tests
* update examples and readme with new optimziers api
* add device_values parameter to xla_call
* change optimizers.py to flatten trees and subtrees
* remove tree_map2, tree_multimap2, tree_mimomap, tree_prefixmap
* add optimizer tests: DeviceTuples and error msgs
* make the device_values arg to jit private
2019-05-03 12:44:52 -07:00
Matthew Johnson
5ae8ab46c8 rename test file too 2019-02-06 11:05:21 -08:00