50 Commits

Author SHA1 Message Date
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Jake VanderPlas
d07951c592 jnp.einsum_path: improve docs & annotations 2024-05-10 08:39:32 -07:00
Jake VanderPlas
c3d3db9b0e jnp.einsum: support optimize=False, and improve docs for this keyword. 2024-05-09 19:50:06 -07:00
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Jake VanderPlas
0da3a7ffb5 jnp.einsum: lower to mixed-precision dot_general when possible.
This is a re-landing of https://github.com/google/jax/pull/16733. The downstream issues should be fixed by https://github.com/google/jax/pull/17152.

Reverts c6f40e202c7f5724b9be61afa33541a8f4abfdd0

PiperOrigin-RevId: 559794120
2023-08-24 10:31:39 -07:00
Parker Schuh
c6f40e202c Reverts 75c3457264f9cc117ff09551ce3174d72689fa3d
PiperOrigin-RevId: 557628297
2023-08-16 16:06:28 -07:00
Jake VanderPlas
14d52fca55 jnp.einsum: lower to mixed-precision dot_general when possible 2023-08-15 15:57:19 -07:00
Matthew Johnson
6bdb5821c3 einsum: inf inputs could cause superfluous nan outputs 2023-07-11 17:19:40 -07:00
Peter Hawkins
803c729b57 Fix some test failures under H100.
It seems that under H100 matmul precisions are a little lower by default than they historically were on A100. Opt out of tensorcore matmuls for tests that fail due to precision issues if they are enabled.

Happily, this also allows us to remove a number of TPU special cases for the same reason.

PiperOrigin-RevId: 539101155
2023-06-09 09:23:36 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Jake VanderPlas
3eb61e1252 jnp.einsum: add preferred_element_type argument 2023-04-08 09:24:39 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Jake VanderPlas
97512e9e44 JaxTestCase: set jax_numpy_rank_promotion='raise' by default 2022-02-14 09:22:05 -08:00
jax authors
5691010d2f Copybara import of the project:
--
d42fffd849a4bac0c0c11a3346c93f07f8c64c44 by Jake VanderPlas <jakevdp@google.com>:

JaxTestCase: set numpy_rank_promotion='raise' by default
PiperOrigin-RevId: 427896974
2022-02-10 19:08:29 -08:00
Jake VanderPlas
6324577a63 JaxTestCase: set numpy_rank_promotion='raise' by default 2022-02-10 16:54:31 -08:00
Peter Hawkins
256e7220ff [JAX] Fix pylint errors.
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...

PiperOrigin-RevId: 400858477
2021-10-04 17:54:46 -07:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Jake VanderPlas
6114e6a0d3 test_util: add decorator to set config values in test cases 2021-08-05 14:06:37 -07:00
Jake VanderPlas
768aba55f1 disable implicit rank promotion in lax_numpy_einsum/indexing/vectorize_test 2021-08-03 12:19:36 -07:00
Peter Hawkins
9832df8ada Try to avoid transposes in jnp.einsum by considering both argument orders to dot_general. 2021-03-04 10:00:08 -05:00
Jake VanderPlas
afce718eb1 Add ability to specify individual test targets 2020-06-29 11:08:57 -07:00
Peter Hawkins
fffdb2daa8
Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_… (#3280)
* Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_util APIs.
Default to check_dtypes=True.

Remove explicit usages of check_dtypes=True from tests. This mostly just removes visual noise from tests. Testing for exact type equality is the sensible default, although there are cases where opting out makes sense.

No functional changes intended.

* Fix a number of lax reference implementations to preserve types.
2020-06-01 17:19:23 -04:00
Peter Hawkins
b543652332
Replace np -> jnp, onp -> np in tests. (#2969) 2020-05-05 14:59:16 -04:00
Peter Hawkins
7116cc5b41
Improve JAX test PRNG APIs to fix correlations between test cases. (#2957)
* Improve JAX test PRNG APIs to fix correlations between test cases.

In #2863, we observed that we were missing gradient problems because the random test cases being generated were too similar because they were formed with identically seeded PRNGs. This change updates the test_util.rand_...() functions to take an explicit numpy.random.RandomState, and adds a rng() method to JaxTestCase to form a RandomState seeded on the test case name.

This gives the following properties:
* different test cases receive different seeds
* PRNG seeding is deterministic and independent of execution order and sharding.
* PRNG seeding is deterministic across runs.

* Fix some failing tests.

* Fix more test failures.

Simplify ediff1d implementation and make it more permissive when casting.

* Relax test tolerance of laplace CDF test.
2020-05-04 23:00:20 -04:00
Peter Hawkins
c758aff88b
Fix some missing cases of broadcasting in np.einsum. (#2512)
* Fix some missing cases of broadcasting in np.einsum.

In particular, np.einsum allows one side of a batch or contracting dimension to have size 1 even if the other side has a non-1 size.

Implement np.matmul in terms of np.einsum. This allows us to reuse einsum's logic for performing broadcasting without explicitly broadcasting the LHS and RHS together.

* Add regression test.

Fixes #2189.
2020-04-01 10:54:47 -04:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
Peter Hawkins
96677d9c6f
Use highest precision for einsum test. (#1876)
Fixes test failures on TPU which uses lower precision by default.
2019-12-17 11:45:39 -05:00
Peter Hawkins
b26a12a358
Implement bool_ support for jnp.add, jnp.multiply, jnp.einsum, lax.do… (#1872)
* Implement bool_ support for jnp.add, jnp.multiply, jnp.einsum, lax.dot and lax.dot_general.

Fix dtype rules for `lax._reduce_sum` and `lax._reduce_prod` to check for number inputs.

Improve error messages for type mismatches to correctly describe scalar type categories (e.g. 'floating') rather than what `onp.dtype(...).name` returns (e.g., 'float64').

Remove redundant `bfloat16` type in `lax._float`, which has been redundant since `dtypes.issubdtype` was taught about `bfloat16` support.
2019-12-16 20:48:19 -05:00
Peter Hawkins
bbf8129aa6
Change test tolerance logic not to choose tolerance values based on f… (#1701)
* Change test tolerance logic not to choose tolerance values based on flags (in particular, --jax_enable_x64).

We would like to move away from having global flags to enable 64-bit mode. We therefore need other methods to select test tolerances. Instead, use a per-type default tolerance, and allow tests to pass per-type dictionaries of tolerances as atol and rtol values. Fix up a number of tolerances to make tests pass.

* Fix test tolerances.

* Fix dtype canonicalization for test tolerances.

* Relax core test_vjp tolerance.
2019-11-16 13:51:42 -05:00
Matthew Johnson
108a2dbb9c tweak an einsum test 2019-04-08 09:52:47 -07:00
Matthew Johnson
46e26a790a add a comment about kpmurphy einsum test 2019-04-06 15:26:33 -07:00
Matthew Johnson
f4e141d30e add 'optimize' kwarg to jax.numpy.einsum 2019-04-06 15:26:33 -07:00
Matthew Johnson
1be9abd322 add jax.numpy.einsum_path (fixes #579) 2019-04-06 10:33:18 -07:00
Peter Hawkins
b06f08e66b Enable the remaining einsum tests.
Fixes #37.
2019-02-08 13:24:08 -05:00
Peter Hawkins
20c9737359 Update jaxlib references to 0.1.7.
Enable einsum tests (issue #37).
2019-02-08 11:48:32 -05:00
Matthew Johnson
52254d7602 tweaks so einsum and spstats tests run internally 2019-02-05 07:52:56 -08:00
Anselm Levskaya
1d26c6bfa1
add test-case for nondeterminism in front batch dim equal case in einsum
This is a test for the observed (and hopefully fixed) nondeterminism in the case of already-front, already-ordered batch_dims.
2019-02-04 16:30:22 -08:00
Matthew Johnson
6bb9609fb8 disable test, py3 opt_einsum nondeterministic bug? 2018-12-19 16:58:31 -08:00
Matthew Johnson
6a138202ef fix several einsum bugs 2018-12-19 16:15:43 -08:00
Matthew Johnson
2b934dae22 einsum test cases from dask (thanks @sjperkins)
The new cases are based on these test cases from dask:
https://github.com/dask/dask/pull/3412
2018-12-19 11:17:14 -08:00
Matthew Johnson
8ba29c2dfc einsum test cases from dask (thanks @sjperkins)
The new cases are based on these test cases from dask:
https://github.com/dask/dask/pull/3412
2018-12-19 11:15:45 -08:00
Matthew Johnson
9a68bce567 add comment marking a bug 2018-12-19 10:42:40 -08:00
Matthew Johnson
166f45bf2b add tests for cases tf.einsum doesn't handle
from https://www.tensorflow.org/api_docs/python/tf/einsum
one currently fails
2018-12-19 10:42:40 -08:00
Matthew Johnson
997c9c5a50 fix einsum tensor product logic (fixes #37)
The error was that `lhs_names` and `rhs_names` included `batch_names` as
prefixes, but the reshaping logic was written as if they did not include
batch_names (and so batch_names had to be prepended).
2018-12-19 07:59:00 -08:00
Matthew Johnson
6261ef729a more einsum improvements (complete?) 2018-12-18 23:20:10 -08:00
Matthew Johnson
d8388e2d80 complete support for two-operand einsum 2018-12-18 23:20:10 -08:00
Matthew Johnson
fdde6841e6 add support for two-operrand cases 2018-12-18 23:20:10 -08:00
Matthew Johnson
13a0e1168e fix broadcasted eye bug, enable more einsum 2018-12-18 23:20:10 -08:00
Matthew Johnson
6a71e9d6ec start drafting an einsum implementation 2018-12-18 23:20:09 -08:00