88 Commits

Author SHA1 Message Date
Tom Hennigan
1becb57ac9 Add jax.copy_to_host_async(tree).
A relatively common pattern I've observed is the following:

```python
_, metrics = some_jax_function()

with profiler.Trace('compute_metrics'):
  jax.block_until_ready(metrics)

with profiler.Trace('copy_to_host'):
  metrics = jax.device_get(metrics)
```

We are missing an opportunity here to more eagerly begin the h2d copy of
the metrics (e.g. overlap it with closing the "compute_metrics" context
manager etc. The intention of `jax.copy_to_host_async(x)` is to make it
simple to begin h2d transfers as early as possible. Adapting the above code:

```python
_, metrics = some_jax_function()

# Begin D2H copies as early as we can.
jax.copy_to_host_async(metrics)

with profiler.Trace('compute_metrics'):
  jax.block_until_ready(metrics)

with profiler.Trace('copy_to_host'):
  metrics = jax.device_get(metrics)
```

PiperOrigin-RevId: 731626446
2025-02-27 01:22:15 -08:00
Stephan Hoyer
458f6a6efe Add jax.test_util to public API docs 2025-01-23 16:04:35 -08:00
Dan Foreman-Mackey
cb4d97aa1f Move jex.ffi to jax.ffi. 2024-12-29 13:06:19 +00:00
Dan Foreman-Mackey
092d2a0db5 Add error message when using custom_vmap with reverse-mode AD, and add docstrings.
The `custom_vmap` API is discussed in https://github.com/jax-ml/jax/issues/9073, and it remains somewhat experimental and incomplete, but it is sufficiently widely used that it seemed worth adding it to the docs.

One specific pain point with `custom_vmap` is that it doesn't support reverse-mode autodiff, so I also added a better error message for this case. Before this change, using `grad` with a `custom_vmap` function would fail with an `assert` deep within the JAX internals. This now fails with a `NotImplementedError` that describes the problem.

PiperOrigin-RevId: 704353963
2024-12-09 11:17:44 -08:00
Yash Katariya
66c6292e6a Make committed a public property of jax.Array.
Why?

Because users need to know if an array is committed or not since JAX raises errors based on committedness of a jax.Array. JAX also makes decisions about dispatching based on committedness of a jax.Array.
But the placement of such arrays on devices is an internal implementation detail.

PiperOrigin-RevId: 686329828
2024-10-15 19:46:10 -07:00
Dan Foreman-Mackey
1f0a04a4fc Add jax.make_mesh to API docs. 2024-10-09 13:55:43 -04:00
Jake VanderPlas
cf51ee7ef0 Improve documentation for jax.jacobian 2024-09-26 05:09:47 -07:00
Yash Katariya
de9b98e0a8 Delete jax.xla_computation since it's been 3 months since it was deprecated.
PiperOrigin-RevId: 673938336
2024-09-12 11:47:38 -07:00
Jake VanderPlas
e9d6fd3795 document jax.Array methods and attributes 2024-08-16 06:37:19 -07:00
jax authors
88ed579704 Merge pull request #23039 from froystig:docs3
PiperOrigin-RevId: 663427005
2024-08-15 13:22:04 -07:00
Jake VanderPlas
91f5512965 Document methods of custom_jvp/custom_vjp 2024-08-14 15:37:20 -07:00
Roy Frostig
12eebfe8d4 docs: reorganize sections
* Create "extension guides" section
* Sort developer notes into subsections
* Move examples from advanced section into user guides
* Reorder some listings, adjust some titles
2024-08-13 21:33:45 -07:00
Dan Foreman-Mackey
60bf5b7727 Add a jax.process_indices function.
The `jax.host_ids` function has be long deprecated, but the suggested alternative of `list(range(jax.process_count()))` relies on the current behavior that the list of process indices is always dense. In the future we may want to allow dynamic addition and removal of processes in which case `jax.process_count` and `jax.process_indices` would need to be updated, and it is useful for users to be able to use this forward-compatible interface.

PiperOrigin-RevId: 662142636
2024-08-12 10:30:41 -07:00
George Necula
105cc9a103 [export] Add documentation for jax.export 2024-06-12 19:44:47 +02:00
Yash Katariya
08b1cef00a Add jax.make_array_from_process_local_data to the docs. Preview: https://jax--21444.org.readthedocs.build/en/21444/_autosummary/jax.make_array_from_process_local_data.html#jax.make_array_from_process_local_data
PiperOrigin-RevId: 637660828
2024-05-27 09:41:43 -07:00
Jake VanderPlas
6934a4b76b Add jax.tree module with aliases of jax.tree_util 2024-02-12 13:07:59 -08:00
Matthew Johnson
808958b69e add jax.custom_gradient to readthedocs 2023-11-20 13:26:52 -08:00
Roy Frostig
ca008f37e3 initiate jax.extend via docs and top-level module set-up 2023-05-15 15:47:06 -07:00
Jake VanderPlas
055edf4a08 DOC: add docstrings for callback functions 2023-04-12 07:33:09 -07:00
Matthew Johnson
057d408448 add docs for jax.clear_caches
Co-authored-by: Roy Frostig <frostig@google.com>
2023-04-07 14:42:31 -07:00
Jake VanderPlas
4a9ed3eaa8 Document ShapeDtypeStruct 2023-03-21 13:53:20 -07:00
Jake VanderPlas
11e32196cc DOC: add docs for jax.dtypes module 2023-02-14 11:18:59 -08:00
Jake VanderPlas
7975192f92 Expose jax.typing & update docs 2023-02-13 15:53:08 -08:00
Peter Hawkins
74f1ab0503 Export Device as jax.Device.
Users are writing things like jax.lib.xla_client.Device in type annotations which is not a public API. Add a supported public name for the Device type.
2023-02-02 12:58:15 -05:00
Jake VanderPlas
81e627d5bd DOC: make API doc titles more uniform 2023-01-18 10:59:42 -08:00
Yash Katariya
e6c4d4a30e Add docstrings for Sharding classes. Right now I am only documenting Sharding, XLACompatibleSharding, MeshPspecSharding and SingleDeviceSharding.
Also moving jax_array_migration guide to reference documentation.

PiperOrigin-RevId: 488489503
2022-11-14 15:47:46 -08:00
Yash Katariya
6897d37562 Add docstrings for jax.Array APIs make_array_from_callback and make_array_from_single_device_arrays.
PiperOrigin-RevId: 487929688
2022-11-11 15:21:10 -08:00
Jake VanderPlas
0fb462efd7 Add jax.print_environment_info() 2022-09-12 15:39:33 -07:00
Roy Frostig
43db06491c write and generate package API documentation for jax.stages 2022-09-01 19:26:53 -07:00
Sharad Vikram
175823e066 Add callbacks to docs 2022-08-24 14:12:47 -07:00
Sharad Vikram
4386a0f909 Add debugging tools under jax.debug and documentation
Co-authored-by: Matthew Johnson <mattjj@google.com>
Co-authored-by: Lena Martens <lenamartens@google.com>
2022-07-28 20:07:26 -07:00
Sharad Vikram
c0b47fdf2c Update changelog for named_scope and adds it to the docs 2022-06-09 11:22:44 -07:00
Peter Hawkins
3e5ecfe363 Add jax.distributed and jax.dlpack to the docs.
Reorder the doc modules into something closer to alphabetical order.

Add missing functions from jax.scipy.linalg and jax.scipy.signal to the docs.
2022-02-17 16:10:07 -05:00
Jake VanderPlas
fc10438b4f DOC: move functions in jax.html to their own pages 2022-02-04 14:40:34 -08:00
Lena Martens
f591d0b2e9
Add ensure_compile_time_eval docstring to docs 2022-01-14 11:18:40 +00:00
Matthew Johnson
1cf7d4ab5d Copybara import of the project:
--
4fcdadbfb3f4c484fd4432203cf13b88782b9311 by Matthew Johnson <mattjj@google.com>:

add jax.ensure_compile_time_eval to public api

aka jax.core.eval_context

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/7987 from google:issue7535 4fcdadbfb3f4c484fd4432203cf13b88782b9311
PiperOrigin-RevId: 420928687
2022-01-10 20:58:26 -08:00
Matthew Johnson
0c68605bf1 add jax.block_until_ready to docs and changelog
also unrelatedly fix a couple of the uses of rst in changelog.md (though
many others remain)
2021-12-14 13:39:47 -08:00
Jake VanderPlas
0e4e30f4e5 DOC: add documentation for configuration functionality 2021-11-29 10:44:54 -08:00
Roy Frostig
623c201054 [JAX] move example libraries from jax.experimental into jax.example_libraries
The `jax.experimental.stax` and `jax.experimental.optimizers` modules are standalone examples libraries. By contrast, the remaining modules in `jax.experimental` are experimental features of the JAX core system. This change moves the two example libraries, and the README that describes them, to `jax.example_libraries` to reflect this distinction.

PiperOrigin-RevId: 404405186
2021-10-19 17:30:45 -07:00
Jake VanderPlas
ff2bfc0e87 DOC: fix docstring and add docs. 2021-09-27 09:48:27 -07:00
jax authors
aea51c83e4 Merge pull request #7188 from cccntu:add-doc-jax.device_get
PiperOrigin-RevId: 392521558
2021-08-23 14:39:37 -07:00
Jonathan Chang
8536780455 Add documentation for jax.device_get 2021-08-24 00:57:17 +08:00
Jake VanderPlas
ca1b273823 add ravel_pytree to generated docs 2021-08-02 17:55:41 -07:00
Skye Wanderman-Milne
9128ba0c74 Replace host_id with process_index terminology, take 2.
We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.

This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.

This was originally commited in
b77ef5138b631378e6a8ceb8bafc94fe91239bae, but reverted in
14acd070c2afb11c81fc91f43790577cd48cbf67 due to Google-internal test
failures from renaming the local_devices argument name. This change is
identical except it also adds staging for the argument name change.
2021-04-20 18:13:34 -07:00
jax authors
14acd070c2 Internal change
PiperOrigin-RevId: 369345279
2021-04-19 18:23:07 -07:00
Skye Wanderman-Milne
b77ef5138b Replace host_id with process_index terminology.
We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.

This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.
2021-04-19 14:09:19 -07:00
Jake VanderPlas
749ad95514 DOC: add transformations doc to HTML & reorganize contents 2021-03-08 16:25:04 -08:00
Jake VanderPlas
067be89a0c DOC: minor documentation & formatting fixes 2021-02-23 10:31:44 -08:00
Tom Hennigan
7adb1e381d Add jax.default_backend() which returns the default platform name.
This can be useful when you need backend specific behaviour, e.g.:

    if jax.default_backend() == 'gpu':
      dataset = double_buffer(dataset)

Or if you want to assert a given backend is the default:

    assert jax.default_backend() == 'tpu'

I am a bit conflicted by the naming, "backend" is consistent with other APIs in
JAX (e.g. jit, local_devices etc) which accept a "backend" string which is used
to lookup an XLA backend by platform name.
2021-02-04 14:50:15 +00:00
Roy Frostig
4adc4362ef include closure_convert in generated docs 2021-01-25 17:42:46 -08:00