117 Commits

Author SHA1 Message Date
Jake VanderPlas
21f6736005 Remove several deprecated APIs 2023-07-11 12:42:32 -07:00
Jake VanderPlas
47ae5bddd7 Mark jax.abstract_arrays as deprecated 2023-06-07 23:36:40 -07:00
Matthew Johnson
e0d2736e37 add custom_jvp for jax.nn.softmax
This avoids saving the jnp.exp(...) value.
2023-04-22 11:28:03 -07:00
Peter Hawkins
a377caec3a Import jax.experimental.compilation_cache.compilation_cache by default.
This is to fix users who were relying on this module being imported as part of 'import jax'.

PiperOrigin-RevId: 525151996
2023-04-18 08:19:45 -07:00
Peter Hawkins
f8fe5d0542 Import jax.experimental.compilation_cache by default
PiperOrigin-RevId: 525033643
2023-04-17 21:28:35 -07:00
Matthew Johnson
26562a4382 [JAX] Add jax.clear_caches, plumb a way to clear pmap caches
fixes #10828

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 522654093
2023-04-07 12:19:00 -07:00
Peter Hawkins
c7b99e6ea9 Import jax.monitoring by default.
A JAX refactoring meant this was no longer being imported by default. Restore the previous state.

PiperOrigin-RevId: 522474571
2023-04-06 17:03:38 -07:00
Peter Hawkins
bf50551e0f Explicitly import jax.custom_{batching,derivatives,transpose}.
https://github.com/google/jax/pull/15391 had the unintentional side effect of causing these names not to be imported by default. Restore the status quo by importing them.

PiperOrigin-RevId: 521898088
2023-04-04 16:40:15 -07:00
Peter Hawkins
c1f65fc8b2 Avoid imports from the public jax.* namespace in more places internally.
This change is in preparation for more cycle breaking in the Bazel dependency graph.

PiperOrigin-RevId: 521822756
2023-04-04 11:41:40 -07:00
Peter Hawkins
abf1acf76c Replace references to jax.interpreters with jax._src.interpreters in JAX core.
PiperOrigin-RevId: 520933067
2023-03-31 08:58:00 -07:00
Jake VanderPlas
fc47137ca8 Add deprecation warnings for several top-level jax imports 2023-03-28 12:40:59 -07:00
Peter Hawkins
6cc1bf54a1 Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval.
Also fix up some other internal imports of jax.interpreters.* to use jax._src.interpreters.

PiperOrigin-RevId: 519813664
2023-03-27 13:30:47 -07:00
Yash Katariya
207cc10058 Error if jax_array or jax_jit_pjit_api_merge is set to False.
PiperOrigin-RevId: 517485597
2023-03-17 12:57:57 -07:00
Peter Hawkins
ed491b3056 Shorten alias chains for names exported in jax. namespace.
Add some additional type annotations on public APIs.

This allows pytype to do a better job of type inference.

PiperOrigin-RevId: 513255770
2023-03-01 09:19:44 -08:00
Peter Hawkins
b61d5d5654 Remove jax._src deletion.
This isn't a completely effective way to close off the JAX private namespace, since it's easy to work around via the module import mechanism.

It also prevents us from fixing users who are mocking JAX internals. Some users, e.g. t5x, have test code like this:

```
from jax._src.lib import xla_bridge

@mock.patch.object(xla_bridge, 'process_index')
...
```

A slightly cleaner solution that does not require importing the JAX internals and does not assume how the internals are laid out is:

```
@mock.patch(f'{jax.process_index.__module__}.process_index')
...
```

However, this solution requires the `jax._src` be present in the JAX namespace.

Ideally users wouldn't mock our internals at all, but that requires significantly more work.

PiperOrigin-RevId: 512295203
2023-02-25 07:17:47 -08:00
Brennan Saeta
893d359933 Export the Shard type.
PiperOrigin-RevId: 511615655
2023-02-22 15:38:08 -08:00
Peter Hawkins
cd0533cab0 Replace uses of jnp.ndarray with jax.Array inside JAX.
PiperOrigin-RevId: 509939691
2023-02-15 14:53:00 -08:00
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Peter Hawkins
0c14e9ab49 Change jax.ad, jax.xla, jax.pxla to point to the shims instead of the internal modules.
Don't hide _deprecations in shim modules, since it's handy for users to override deprecations locally, e.g., to verify there are no remaining users.

Fix some overly-strict type annotations.

PiperOrigin-RevId: 508461199
2023-02-09 13:31:40 -08:00
Peter Hawkins
74f1ab0503 Export Device as jax.Device.
Users are writing things like jax.lib.xla_client.Device in type annotations which is not a public API. Add a supported public name for the Device type.
2023-02-02 12:58:15 -05:00
Jake VanderPlas
43e57db77a Begin deprecation of public jax.ShapedArray 2023-01-30 11:27:58 -08:00
Jake VanderPlas
f317943f56 Warn rather than fail when reloading JAX
Fixes https://github.com/google/jax/issues/13857

PiperOrigin-RevId: 500727768
2023-01-09 09:11:50 -08:00
Jongwook Choi
479f33680d Make jax.ensure_compile_time_eval correctly exposed as a public API
This function was added as a public API (#7987) but py.type static
checkers do not recognize it as a public API because of the alias name.

`jax.eval_context` exists only for backward compatibility, so the
correct import would be to import `ensure_compile_time_eval` directly
from `jax._src.core`.
2022-12-21 20:12:08 -05:00
Jake VanderPlas
4a6bbde409 Move jax.linear_util to jax._src.linear_util 2022-12-20 14:49:27 -08:00
Roy Frostig
d927a5dbf3 migrate internal dependencies from jax.core to jax._src.core
... in preparation for paring down `jax.core`'s exported symbols.

Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.

PiperOrigin-RevId: 496024782
2022-12-16 21:00:14 -08:00
Jake VanderPlas
26f2f97805 Document why 'import name as name' is used 2022-12-14 15:07:04 -08:00
Yash Katariya
ca1f58e37b Add a new jax.spmd_mode config for preventing unintentional hangs and incorrect results when users pass jax.Arrays that span across multiple processes (i.e. not fully addressable) to jit or jnp operations (that are jitted by default).
Implicitly jitted functions will **always** require a `jax.spmd_mode` context manager for operating on non-fully addressable jax.Array.

Explicitly jitted functions will require the `jax.spmd_mode` config to begin with as we roll out jax.Array since its a new behavior for `jit` (previously jit only worked on single device arrays).
* Overtime (via docs) and as users become more familiar with the new parallelism APIs, we can relax this restriction and allow explicit `jit` to work without needing the config. This can happen when we merge the frontend of `jit` and `pjit`.

PiperOrigin-RevId: 485075693
2022-10-31 09:51:42 -07:00
Kuangyuan Chen
57eb19f3ea Add a warning to device.live_buffers() as it is going to be deprecated with jax.Array and instruct users to use jax.live_arrays() instead.
PiperOrigin-RevId: 484533292
2022-10-28 08:11:51 -07:00
Albert Alonso
6ddbe5d5ec expose jax.block_until_ready() from jax 2022-10-13 18:31:47 +02:00
Peter Hawkins
9ab88071a7 Avoid loading scipy eagerly.
scipy accounts for around 400ms of the 900ms of JAX's import time. By
loading scipy lazily, we can improve the timing of `import jax` down to
about 500ms.
2022-10-12 19:51:09 +00:00
Yash Katariya
9e4114f0f1 Move array.py and sharding.py from experimental/ to _src/.
PiperOrigin-RevId: 477201711
2022-09-27 10:06:52 -07:00
Jake VanderPlas
0cb233eec9 Add initial jax.Array base class for instance checks & annotation 2022-09-26 07:48:43 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Yash Katariya
da90234cae Delete soft_pmap as it has no users. Please use pjit or xmap if you do want soft_pmap.
`jax.soft_pmap` is undocumented. If it were documented, a deprecation period would have been provided.

PiperOrigin-RevId: 474145090
2022-09-13 15:52:10 -07:00
Jake VanderPlas
0fb462efd7 Add jax.print_environment_info() 2022-09-12 15:39:33 -07:00
Peter Hawkins
40c80d7d0a Remove jax._src from JAX namespace.
This is a JAX-internal name and not subject to any deprecation policy. Please avoid the use of JAX-internal functions outside JAX.

PiperOrigin-RevId: 473243243
2022-09-09 07:06:00 -07:00
Sharad Vikram
393bca122d Expose pure callback and enable rank polymorphic callbacks 2022-08-17 10:56:42 -07:00
jax authors
560c936a46 Merge pull request #11653 from sharadmv:debugging-docs
PiperOrigin-RevId: 463988525
2022-07-28 20:26:25 -07:00
Sharad Vikram
4386a0f909 Add debugging tools under jax.debug and documentation
Co-authored-by: Matthew Johnson <mattjj@google.com>
Co-authored-by: Lena Martens <lenamartens@google.com>
2022-07-28 20:07:26 -07:00
Lena Martens
8ca5ecc7f3 Re-land #11498 after internal fixes.
maintain an alias to `jax.tree_util.tree_map` in the top level `jax` module

PiperOrigin-RevId: 463885774
2022-07-28 11:33:34 -07:00
Jake VanderPlas
108376d792 Remove deprecated function jax.tree_util.tree_multimap 2022-07-26 09:37:27 -07:00
George Necula
ab7d036271 Remove dependencies on masking.py 2022-07-25 11:25:26 +03:00
George Necula
66dc95e2de removes the jax.mask and jax.shapecheck APIs.
PiperOrigin-RevId: 463026577
2022-07-25 01:23:38 -07:00
Kuangyuan Chen
c0ec3b33e6 Introduce jax.experimental.clear_backends to delete all JAX runtime backends.
In cases like unit tests, users may want to clean up all the backends along with the resources used in the end of the test, and reinitialize them in the next test.

PiperOrigin-RevId: 462239974
2022-07-20 15:10:27 -07:00
jax authors
023e6f5955 Copybara import of the project:
--
e1f1e93e0c8b53e62a064b06b56c84a2bfedb911 by Roy Frostig <frostig@google.com>:

maintain an alias to `jax.tree_util.tree_map` in the top level `jax` module

PiperOrigin-RevId: 461146464
2022-07-15 01:23:51 -07:00
Roy Frostig
e1f1e93e0c maintain an alias to jax.tree_util.tree_map in the top level jax module 2022-07-14 11:00:54 -07:00
Jake VanderPlas
ce08a9fc5c Deprecate top-level aliases of jax.tree_util functions 2022-07-07 11:41:46 -07:00
Sharad Vikram
289610eb02 Add a public facing named_scope function to allow adding to the name stack. 2022-06-08 17:23:57 -07:00
jax authors
ea54754c49 Merge pull request #9118 from skye:device_context_manager
PiperOrigin-RevId: 452570041
2022-06-02 10:33:53 -07:00
Jake VanderPlas
ceae6fe5e2 Add jax_numpy_dtype_promotion='strict' mode 2022-05-26 10:56:09 -07:00