Yue Sheng
147c363ea6
Deprecate jax.clear_backends
.
...
`jax.clear_backends` does not necessarily do what its name suggests and can lead to unexpected consequences, e.g., it will not destroy existing backends and release corresponding owned resources. Use `jax.clear_caches` if you only want to clean up compilation caches. For backward compatibilty or you really need to switch/reinitialize the default backend, use `jax.extend.backend.clear_backends`.
PiperOrigin-RevId: 616946337
2024-03-18 14:23:18 -07:00
Jake VanderPlas
236275ebe1
Deprecate jax.tree_map for jax v0.4.26
...
Reverts f4045dceb206be1ea10ee651ccc6151809f2d9f3
PiperOrigin-RevId: 611230367
2024-02-28 14:29:01 -08:00
Yash Katariya
f4045dceb2
Remove the deprecation of jax.tree_map for the release of 0.4.25
...
PiperOrigin-RevId: 610014256
2024-02-24 09:30:06 -08:00
Jake VanderPlas
e59a0506fe
Deprecate jax.tree_map in favor of jax.tree.map
2024-02-22 11:35:39 -08:00
Jake VanderPlas
cf80f574b5
Register jax.config module deprecation
...
PiperOrigin-RevId: 609352291
2024-02-22 06:38:56 -08:00
Sergei Lebedev
57e59eb6c3
Removed deprecated jax.config methods and jax.config.config
...
Reverts dcc65e621ea3a68fdc79fa9f2c995743a7b3faf7
PiperOrigin-RevId: 608676645
2024-02-20 11:25:16 -08:00
Thomas Köppe
dcc65e621e
Reverts b506fee9e389391efb1336bc7575dba913e75cdf
...
PiperOrigin-RevId: 608319964
2024-02-19 06:23:00 -08:00
Sergei Lebedev
b506fee9e3
Removed deprecated jax.config methods and jax.config.config
...
Reverts eb0343683547b6e2d29245f3ab6c91037c0cff81
PiperOrigin-RevId: 607803834
2024-02-19 06:21:15 -08:00
jax authors
eb03436835
Reverts 318a19a89387caebd116168c4e47592e7d71ca65
...
PiperOrigin-RevId: 607708463
2024-02-16 09:11:05 -08:00
Sergei Lebedev
318a19a893
Removed deprecated jax.config methods
...
PiperOrigin-RevId: 607675571
2024-02-16 06:49:13 -08:00
Jake VanderPlas
6934a4b76b
Add jax.tree module with aliases of jax.tree_util
2024-02-12 13:07:59 -08:00
Jake VanderPlas
e356d76913
Remove a number of deprecated APIs
...
All of these were deprecated prior to the JAX 0.4.16 release, on Sept 18 2023.
As of Monday Dec 18, we have met the 3 month deprecation period specified by the [API Compatiblity Policy](https://jax.readthedocs.io/en/latest/api_compatibility.html ).
PiperOrigin-RevId: 591933493
2023-12-18 10:08:47 -08:00
Jake VanderPlas
a52d18781e
Add experimental static key reuse checking
2023-12-11 12:03:48 -08:00
Roy Frostig
ed9a4c2939
add jax.threefry_partitionable
context manager
2023-10-31 13:45:55 -07:00
Sergei Lebedev
f2ce5dbd01
MAINT Do not use str()
and repr()
in f-string replacement fields
...
`str()` is called by default by the formatting machinery, and `repr()` only
needs `!r`.
2023-10-23 15:12:04 +01:00
Jake VanderPlas
024b1f23d7
Remove deprecated submodule jax.abstract_arrays
2023-09-19 15:40:18 -07:00
Jake VanderPlas
1800015884
Import jax.version first
2023-09-12 12:27:20 -07:00
Jake VanderPlas
ca39457ea9
JEX: move jax.linear_util to jax.extend.linear_util
2023-08-30 18:32:12 -07:00
Muhammad Abdullah
599b35eeaa
Update __init__.py
2023-08-28 10:13:16 +05:00
Muhammad Abdullah
d09c55aa2d
Update __init__.py to include dlpack module
...
import for dlpack module was missing in `__init__.py` file. just added that
2023-08-28 10:01:17 +05:00
Jake VanderPlas
630a69f41b
[random] add jax_legacy_prng_key flag
2023-08-22 15:08:51 -07:00
Skye Wanderman-Milne
8b58e38ec5
Add jax_debug_log_modules
config option.
...
This can be used to enable debug logging for specific files
(e.g. `JAX_DEBUG_LOG_MODULES="jax._src.xla_bridge,jax._src.dispatch"`)
or all jax (`JAX_DEBUG_LOG_MODULES="jax"`).
Example output:
```
$ JAX_DEBUG_LOG_MODULES=jax python3 -c "import jax; jax.numpy.add(1,1)"
DEBUG:2023-06-07 00:27:57,399:jax._src.xla_bridge:352: No jax_plugins namespace packages available
DEBUG:2023-06-07 00:27:57,488:jax._src.path:29: etils.epath found. Using etils.epath for file I/O.
DEBUG:2023-06-07 00:27:57,663:jax._src.dispatch:272: Finished tracing + transforming fn for pjit in 0.0005719661712646484 sec
DEBUG:2023-06-07 00:27:57,664:jax._src.xla_bridge:590: Initializing backend 'tpu'
DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:602: Backend 'tpu' initialized
DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:590: Initializing backend 'cpu'
DEBUG:2023-06-07 00:28:00,542:jax._src.xla_bridge:602: Backend 'cpu' initialized
DEBUG:2023-06-07 00:28:00,544:jax._src.interpreters.pxla:1890: Compiling fn for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:2023-06-07 00:28:00,547:jax._src.dispatch:272: Finished jaxpr to MLIR module conversion jit(fn) in 0.0023522377014160156 sec
DEBUG:2023-06-07 00:28:00,547:jax._src.xla_bridge:140: get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]]
DEBUG:2023-06-07 00:28:00,571:jax._src.dispatch:272: Finished XLA compilation of jit(fn) in 0.023587703704833984 sec
```
2023-07-28 18:11:12 +00:00
Jake VanderPlas
2691d7edb9
Use standard framework for jax.tree* deprecation.
2023-07-20 12:58:17 -07:00
Peter Hawkins
651f87733b
Remove jax_jit_pjit_api_merge.
...
PiperOrigin-RevId: 548236671
2023-07-14 15:25:00 -07:00
Jake VanderPlas
21f6736005
Remove several deprecated APIs
2023-07-11 12:42:32 -07:00
Jake VanderPlas
47ae5bddd7
Mark jax.abstract_arrays as deprecated
2023-06-07 23:36:40 -07:00
Matthew Johnson
e0d2736e37
add custom_jvp for jax.nn.softmax
...
This avoids saving the jnp.exp(...) value.
2023-04-22 11:28:03 -07:00
Peter Hawkins
a377caec3a
Import jax.experimental.compilation_cache.compilation_cache by default.
...
This is to fix users who were relying on this module being imported as part of 'import jax'.
PiperOrigin-RevId: 525151996
2023-04-18 08:19:45 -07:00
Peter Hawkins
f8fe5d0542
Import jax.experimental.compilation_cache by default
...
PiperOrigin-RevId: 525033643
2023-04-17 21:28:35 -07:00
Matthew Johnson
26562a4382
[JAX] Add jax.clear_caches, plumb a way to clear pmap caches
...
fixes #10828
Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 522654093
2023-04-07 12:19:00 -07:00
Peter Hawkins
c7b99e6ea9
Import jax.monitoring by default.
...
A JAX refactoring meant this was no longer being imported by default. Restore the previous state.
PiperOrigin-RevId: 522474571
2023-04-06 17:03:38 -07:00
Peter Hawkins
bf50551e0f
Explicitly import jax.custom_{batching,derivatives,transpose}.
...
https://github.com/google/jax/pull/15391 had the unintentional side effect of causing these names not to be imported by default. Restore the status quo by importing them.
PiperOrigin-RevId: 521898088
2023-04-04 16:40:15 -07:00
Peter Hawkins
c1f65fc8b2
Avoid imports from the public jax.* namespace in more places internally.
...
This change is in preparation for more cycle breaking in the Bazel dependency graph.
PiperOrigin-RevId: 521822756
2023-04-04 11:41:40 -07:00
Peter Hawkins
abf1acf76c
Replace references to jax.interpreters with jax._src.interpreters in JAX core.
...
PiperOrigin-RevId: 520933067
2023-03-31 08:58:00 -07:00
Jake VanderPlas
fc47137ca8
Add deprecation warnings for several top-level jax imports
2023-03-28 12:40:59 -07:00
Peter Hawkins
6cc1bf54a1
Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval.
...
Also fix up some other internal imports of jax.interpreters.* to use jax._src.interpreters.
PiperOrigin-RevId: 519813664
2023-03-27 13:30:47 -07:00
Yash Katariya
207cc10058
Error if jax_array or jax_jit_pjit_api_merge is set to False.
...
PiperOrigin-RevId: 517485597
2023-03-17 12:57:57 -07:00
Peter Hawkins
ed491b3056
Shorten alias chains for names exported in jax. namespace.
...
Add some additional type annotations on public APIs.
This allows pytype to do a better job of type inference.
PiperOrigin-RevId: 513255770
2023-03-01 09:19:44 -08:00
Peter Hawkins
b61d5d5654
Remove jax._src deletion.
...
This isn't a completely effective way to close off the JAX private namespace, since it's easy to work around via the module import mechanism.
It also prevents us from fixing users who are mocking JAX internals. Some users, e.g. t5x, have test code like this:
```
from jax._src.lib import xla_bridge
@mock.patch.object(xla_bridge, 'process_index')
...
```
A slightly cleaner solution that does not require importing the JAX internals and does not assume how the internals are laid out is:
```
@mock.patch(f'{jax.process_index.__module__}.process_index')
...
```
However, this solution requires the `jax._src` be present in the JAX namespace.
Ideally users wouldn't mock our internals at all, but that requires significantly more work.
PiperOrigin-RevId: 512295203
2023-02-25 07:17:47 -08:00
Brennan Saeta
893d359933
Export the Shard
type.
...
PiperOrigin-RevId: 511615655
2023-02-22 15:38:08 -08:00
Peter Hawkins
cd0533cab0
Replace uses of jnp.ndarray with jax.Array inside JAX.
...
PiperOrigin-RevId: 509939691
2023-02-15 14:53:00 -08:00
Roy Frostig
cb8dcce2fe
migrate more internal dependencies from jax.core
to jax._src.core
...
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Peter Hawkins
0c14e9ab49
Change jax.ad, jax.xla, jax.pxla to point to the shims instead of the internal modules.
...
Don't hide _deprecations in shim modules, since it's handy for users to override deprecations locally, e.g., to verify there are no remaining users.
Fix some overly-strict type annotations.
PiperOrigin-RevId: 508461199
2023-02-09 13:31:40 -08:00
Peter Hawkins
74f1ab0503
Export Device as jax.Device.
...
Users are writing things like jax.lib.xla_client.Device in type annotations which is not a public API. Add a supported public name for the Device type.
2023-02-02 12:58:15 -05:00
Jake VanderPlas
43e57db77a
Begin deprecation of public jax.ShapedArray
2023-01-30 11:27:58 -08:00
Jake VanderPlas
f317943f56
Warn rather than fail when reloading JAX
...
Fixes https://github.com/google/jax/issues/13857
PiperOrigin-RevId: 500727768
2023-01-09 09:11:50 -08:00
Jongwook Choi
479f33680d
Make jax.ensure_compile_time_eval
correctly exposed as a public API
...
This function was added as a public API (#7987 ) but py.type static
checkers do not recognize it as a public API because of the alias name.
`jax.eval_context` exists only for backward compatibility, so the
correct import would be to import `ensure_compile_time_eval` directly
from `jax._src.core`.
2022-12-21 20:12:08 -05:00
Jake VanderPlas
4a6bbde409
Move jax.linear_util to jax._src.linear_util
2022-12-20 14:49:27 -08:00
Roy Frostig
d927a5dbf3
migrate internal dependencies from jax.core
to jax._src.core
...
... in preparation for paring down `jax.core`'s exported symbols.
Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.
PiperOrigin-RevId: 496024782
2022-12-16 21:00:14 -08:00
Jake VanderPlas
26f2f97805
Document why 'import name as name' is used
2022-12-14 15:07:04 -08:00