82 Commits

Author SHA1 Message Date
Peter Hawkins
40c80d7d0a Remove jax._src from JAX namespace.
This is a JAX-internal name and not subject to any deprecation policy. Please avoid the use of JAX-internal functions outside JAX.

PiperOrigin-RevId: 473243243
2022-09-09 07:06:00 -07:00
Sharad Vikram
393bca122d Expose pure callback and enable rank polymorphic callbacks 2022-08-17 10:56:42 -07:00
jax authors
560c936a46 Merge pull request #11653 from sharadmv:debugging-docs
PiperOrigin-RevId: 463988525
2022-07-28 20:26:25 -07:00
Sharad Vikram
4386a0f909 Add debugging tools under jax.debug and documentation
Co-authored-by: Matthew Johnson <mattjj@google.com>
Co-authored-by: Lena Martens <lenamartens@google.com>
2022-07-28 20:07:26 -07:00
Lena Martens
8ca5ecc7f3 Re-land #11498 after internal fixes.
maintain an alias to `jax.tree_util.tree_map` in the top level `jax` module

PiperOrigin-RevId: 463885774
2022-07-28 11:33:34 -07:00
Jake VanderPlas
108376d792 Remove deprecated function jax.tree_util.tree_multimap 2022-07-26 09:37:27 -07:00
George Necula
ab7d036271 Remove dependencies on masking.py 2022-07-25 11:25:26 +03:00
George Necula
66dc95e2de removes the jax.mask and jax.shapecheck APIs.
PiperOrigin-RevId: 463026577
2022-07-25 01:23:38 -07:00
Kuangyuan Chen
c0ec3b33e6 Introduce jax.experimental.clear_backends to delete all JAX runtime backends.
In cases like unit tests, users may want to clean up all the backends along with the resources used in the end of the test, and reinitialize them in the next test.

PiperOrigin-RevId: 462239974
2022-07-20 15:10:27 -07:00
jax authors
023e6f5955 Copybara import of the project:
--
e1f1e93e0c8b53e62a064b06b56c84a2bfedb911 by Roy Frostig <frostig@google.com>:

maintain an alias to `jax.tree_util.tree_map` in the top level `jax` module

PiperOrigin-RevId: 461146464
2022-07-15 01:23:51 -07:00
Roy Frostig
e1f1e93e0c maintain an alias to jax.tree_util.tree_map in the top level jax module 2022-07-14 11:00:54 -07:00
Jake VanderPlas
ce08a9fc5c Deprecate top-level aliases of jax.tree_util functions 2022-07-07 11:41:46 -07:00
Sharad Vikram
289610eb02 Add a public facing named_scope function to allow adding to the name stack. 2022-06-08 17:23:57 -07:00
jax authors
ea54754c49 Merge pull request #9118 from skye:device_context_manager
PiperOrigin-RevId: 452570041
2022-06-02 10:33:53 -07:00
Jake VanderPlas
ceae6fe5e2 Add jax_numpy_dtype_promotion='strict' mode 2022-05-26 10:56:09 -07:00
Sharad Vikram
268b4be21b Add output token for unordered effects
Currently we can't block on *unordered* effectful computations because
there are no runtime tokens for them. This change adds a per-device token
that is returned by effectful computations. This enables us
to block on them if we want. See the design note added in https://github.com/google/jax/pull/10657.

PiperOrigin-RevId: 449106281
2022-05-16 18:56:33 -07:00
Skye Wanderman-Milne
f26b866e08 Add jax.default_device context manager
This currently only supports setting a specific Device object, not a
platform like "cpu". That should be added in the future.

Bumps the minimum jaxlib version in order to include
https://github.com/tensorflow/tensorflow/pull/53656
2022-05-07 00:31:00 +00:00
Jake VanderPlas
5782210174 CI: fix flake8 ignore declarations 2022-04-21 13:44:12 -07:00
Matthew Johnson
902fc0c3d2 Remove invertible_ad since it's not in use.
PiperOrigin-RevId: 440890949
2022-04-11 07:56:58 -07:00
Sharad Vikram
d72a7b4054 Add version int tuple __version_info__ to JAX 2022-04-05 13:26:05 -07:00
Roy Frostig
e0ddd43933 add custom_vjp upgrade flag to jax module and fix doc rendering for upgrade flags 2022-03-28 20:36:35 -07:00
Roy Frostig
047488446b factor AOT types out to a stages module 2022-03-15 15:18:15 -07:00
Hyeontaek Lim
beaa00c460 Implement the JAX transfer guard API
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.

The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).

The transfer guard can take an action based on its guard level:

* "allow": Silently allow all transfers (default; same as the previous behavior).
* "log": Log and allow implicit transfers. Silently allow explicit transfers.
* "disallow": Disallow implicit transfers. Silently allow explicit transfers.
* "log_explicit": Log and allow all transfers.
* "disallow_explicit": Disallow all transfers.

The API also allows fine-control the transfer guard level of individual transfer directions. Their flag and context manager names are suffixed with the transfer direction:

* "host_to_device": Converting a Python value into a `DeviceBuffer`.
* "device_to_device": Copying a `DeviceBuffer` to a different device.
* "device_to_host": Fetching the value of a `DeviceBuffer`.

Example:
```
x = jnp.array(1)
y = jnp.array(2)
z = jnp.array(3)

print(x)  # No error
with jax.transfer_guard("disallow"):
  print(x)  # No error; x is already fetched
  print(jax.device_get(y))  # No error
  print(z)  # Error!
```

PiperOrigin-RevId: 428590081
2022-02-14 13:11:49 -08:00
Peter Hawkins
74506c7dda Rollback of: Implement the JAX transfer guard API
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.

The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).

The transfer guard can take an action based on its guard level:

* "allow": Silently allow all transfers (default; s...

PiperOrigin-RevId: 427576107
2022-02-09 14:44:45 -08:00
Hyeontaek Lim
b7e1fec250 Implement the JAX transfer guard API
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.

The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).

The transfer guard can take an action based on its guard level:

* "allow": Silently allow all transfers (default; same as the previous behavior).
* "log": Log and allow implicit transfers. Silently allow explicit transfers.
* "disallow": Disallow implicit transfers. Silently allow explicit transfers.
* "log_explicit": Log and allow all transfers.
* "disallow_explicit": Disallow all transfers.

The API also allows fine-control the transfer guard level of individual transfer directions. Their flag and context manager names are suffixed with the transfer direction:

* "host_to_device": Converting a Python value into a `DeviceBuffer`.
* "device_to_device": Copying a `DeviceBuffer` to a different device.
* "device_to_host": Fetching the value of a `DeviceBuffer`.

Example:
```
x = jnp.array(1)
y = jnp.array(2)
z = jnp.array(3)

print(x)  # No error
with jax.transfer_guard("disallow"):
  print(x)  # No error; x is already fetched
  print(jax.device_get(y))  # No error
  print(z)  # Error!
```

PiperOrigin-RevId: 427562278
2022-02-09 13:50:25 -08:00
George Necula
5bfe1852a4 [jax2tf] Add jax2tf_associative_scan_reductions flag
This flag allows users to match the JAX performance for
associative reductions in CPU.
See README.md for details.
2022-01-13 15:52:18 +02:00
Matthew Johnson
1cf7d4ab5d Copybara import of the project:
--
4fcdadbfb3f4c484fd4432203cf13b88782b9311 by Matthew Johnson <mattjj@google.com>:

add jax.ensure_compile_time_eval to public api

aka jax.core.eval_context

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/7987 from google:issue7535 4fcdadbfb3f4c484fd4432203cf13b88782b9311
PiperOrigin-RevId: 420928687
2022-01-10 20:58:26 -08:00
Matthew Johnson
c8a34fe5cc add jax.block_until_ready function
fixes #8536
2021-12-14 11:02:14 -08:00
Peter Hawkins
4e21922055 Use imports relative to the jax package consistently, rather than .-relative imports.
This is more consistent, since currently we use a mix of both styles. It may also help pytype yield more accurate types.

PiperOrigin-RevId: 412057514
2021-11-24 07:48:29 -08:00
Roy Frostig
cdc9df2b66 add Lowered and Compiled to the jax public package 2021-11-10 18:57:39 -08:00
Qiao Zhang
0be30fbf96 Add jax.distributed.initialize for multi-host GPU. 2021-10-26 14:37:54 -07:00
Roy Frostig
98d245ebb4 add a config setting to control the default PRNG implementation
Also add explicit seeding functions for each PRNG implementation.
2021-10-07 21:22:40 -07:00
Jake VanderPlas
37719e4ad5 Import numpy explicitly in the jax namespace 2021-09-23 09:02:14 -07:00
Peter Hawkins
2c2f4033cc Move contents of jax.lib to jax._src.lib.
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.

PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
Peter Hawkins
f35ab3693d Remove jax.partial from the JAX API.
Use functools.partial instead.
2021-09-20 09:19:53 -04:00
Peter Hawkins
6a1b626564 Remove jax.api.
Functions exported as jax.api were aliases for names in jax.*. Use the jax.* names instead.
2021-09-16 16:29:06 -04:00
Jake VanderPlas
33e2bed1b4 Fix package exports 2021-09-14 13:55:55 -07:00
Peter Hawkins
8b2123968a Switch internal users of jax.util.partial to use functools.partial. 2021-09-13 21:09:58 -04:00
Peter Hawkins
a84426cb8f Switch internal users of jax.ops.index_... to use x.at[x].set() APIs. 2021-09-13 19:48:29 -04:00
Jake VanderPlas
245581411e Add PEP484-compatible export for jax and its subpackages 2021-09-13 14:08:48 -07:00
Peter Hawkins
e869e5e0f8 Move contents of jax.api_util to jax._src.api_util and add a forwarding shim.
One of many changes to codify the set of exported symbols in the jax.* namespace.

PiperOrigin-RevId: 395484706
2021-09-08 09:00:56 -07:00
Matthew Johnson
2d28951ba4 address comments form @apaszke 2021-08-26 14:10:58 -07:00
Jake VanderPlas
1a397a5572 import in the JAX namespace 2021-08-24 08:44:05 -07:00
Roy Frostig
60e0e9f929 implement backwards-compatible behavior and enable custom PRNGs only conditionally
Introduce a config flag for upgrading to a world of custom PRNGs. The
flag defaults off, so that we can introduce custom PRNGs into the
codebase and allow downstream libraries time to upgrade.

Backwards compatible behavior is meant in an external sense. This does
not mean that our code is internally the same any longer.
2021-08-19 20:43:11 -07:00
Peter Hawkins
46cc654537 Move jax.abstract_arrays to jax._src.abstract_arrays.
PiperOrigin-RevId: 377044255
2021-06-02 06:25:22 -07:00
Skye Wanderman-Milne
9128ba0c74 Replace host_id with process_index terminology, take 2.
We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.

This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.

This was originally commited in
b77ef5138b631378e6a8ceb8bafc94fe91239bae, but reverted in
14acd070c2afb11c81fc91f43790577cd48cbf67 due to Google-internal test
failures from renaming the local_devices argument name. This change is
identical except it also adds staging for the argument name change.
2021-04-20 18:13:34 -07:00
jax authors
14acd070c2 Internal change
PiperOrigin-RevId: 369345279
2021-04-19 18:23:07 -07:00
Skye Wanderman-Milne
b77ef5138b Replace host_id with process_index terminology.
We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.

This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.
2021-04-19 14:09:19 -07:00
Peter Hawkins
14d991dd90 Move jax.config to jax._src.config.
PiperOrigin-RevId: 369230109
2021-04-19 08:53:12 -07:00
Peter Hawkins
26e9ebcdae Move jax.api to jax._src.api.
PiperOrigin-RevId: 368233837
2021-04-13 09:43:24 -07:00