19 Commits

Author SHA1 Message Date
Roy Frostig
d927a5dbf3 migrate internal dependencies from jax.core to jax._src.core
... in preparation for paring down `jax.core`'s exported symbols.

Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.

PiperOrigin-RevId: 496024782
2022-12-16 21:00:14 -08:00
Jake VanderPlas
29942e312b docs: add another example to the ConcretizationTypeError docs 2022-12-05 11:24:54 -08:00
Yash Katariya
a419e1917a Use jax.Array by default for doctests
PiperOrigin-RevId: 488719467
2022-11-15 11:52:22 -08:00
Peter Hawkins
cd84eb10a6 Add a number of missing function cross-references in the docs. 2022-11-07 12:00:26 -05:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Matthew Johnson
e3a92d52ba prepare to switch to new remat
This commit involves a few things, which are all united in being about landing
the new remat (aka new checkpoint) implementation:
  * add benchmarks for new remat eager performance, and some caching to make those
    benchmarks fast
  * warn when the old-remat-exclusive `concrete` feature is used, with an
    actionable message pointing to the new recommended approach involving static_argnums
  * add the static_argnums parameter to both new and old remt
  * update docstrings (and de-duplicate them to)
  * add new tests, especially around caching and errors/warnings
2022-08-04 12:25:03 -07:00
Peter Hawkins
4e21922055 Use imports relative to the jax package consistently, rather than .-relative imports.
This is more consistent, since currently we use a mix of both styles. It may also help pytype yield more accurate types.

PiperOrigin-RevId: 412057514
2021-11-24 07:48:29 -08:00
Peter Hawkins
9ea55468ab [JAX] Update users of jax.ops.index... functions, which are deprecated.
* replace uses of `jax.ops.index[...]` with `jax.numpy.index_exp[...]`, which is a standard NumPy function that does the same thing.
* remove some redundant uses of `jax.ops.index[...]`, where the expression is passed directly to an indexed accessor function like `.at[...]`.
* update some remaining users of `jax.ops.index_update(x, jax.ops.index[idx], y)` to use the `x.at[idx].set(y)` APIs.

PiperOrigin-RevId: 406162068
2021-10-28 09:54:26 -07:00
Peter Hawkins
8b2123968a Switch internal users of jax.util.partial to use functools.partial. 2021-09-13 21:09:58 -04:00
Jake VanderPlas
00f36173bd Specify weak_type in DeviceArray repr 2021-08-23 13:19:33 -07:00
lenamartens
f966e7ef5c Fix some of the formatting and reword some of the sections. 2021-07-28 13:20:52 +01:00
Lena Martens
201cca8c42
Apply suggestions from code review
Co-authored-by: Roy Frostig <froystig@users.noreply.github.com>
2021-07-28 11:55:32 +01:00
Lena Martens
19ee7b22e1 Expose UnexpectedTracerError and add docs. 2021-07-27 23:23:28 +01:00
Matthew Johnson
d21e8c0657 handle case where trace debug_info is None 2021-05-06 18:38:20 -07:00
Matthew Johnson
b9d72a480f improve concreteness error from arguments
also tweak some error message wording
2021-05-03 17:37:34 -07:00
Matthew Johnson
2b79264354 remove disable_omnistaging mechanism 2021-03-29 15:26:57 -07:00
Jake VanderPlas
0796bfe6e7 errors: add NonConcreteBooleanIndexError & debugging tips 2021-03-23 11:23:20 -07:00
Jake VanderPlas
e9195ba626 Fix URL in custom errors 2021-03-16 09:10:10 -07:00
Jake VanderPlas
12c84e7a50 Add jax.errors submodule & error troubleshooting docs 2021-03-03 12:39:12 -08:00