2161 Commits

Author SHA1 Message Date
Matthew Johnson
97a0d752c9 [shard-map] add docs for VMAs
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2025-04-09 21:02:57 +00:00
jax authors
b8d9e7f427 Merge pull request #27503 from kaixih:enable_doc_scaled_dot
PiperOrigin-RevId: 745322012
2025-04-08 15:50:54 -07:00
Sergei Lebedev
62df2e8d86 Added jax.no_tracing to the API docs
PiperOrigin-RevId: 745247778
2025-04-08 12:32:35 -07:00
Peter Hawkins
e02faabfb2 Replace references to jax.readthedocs.io with docs.jax.dev.
PiperOrigin-RevId: 745156931
2025-04-08 08:33:49 -07:00
Adam Paszke
511f78202f Add a skeleton for Pallas:Mosaic GPU documentation 2025-04-08 13:13:51 +00:00
Meesum Qazalbash
12b1a99ad9
fix(docs): corrected the name of the function call in the document 2025-04-04 11:35:46 +05:00
Olli Lupton
297a4f42de docs: compilation_cache_expect_pgle option 2025-04-02 14:42:10 +00:00
Roy Frostig
6fe6d80506 upgrade docs from jax.core to jax.extend.core where needed to fix doc build 2025-04-01 22:17:53 -07:00
Jake VanderPlas
dafebd0d7f DOC: add documentation note about default dtypes 2025-03-28 15:20:58 -07:00
jax authors
358c55d066 Update instructions for usage of :build_jaxlib=false flag.
By adding [jax wheel testing](https://github.com/jax-ml/jax/pull/27113) functionality, we need to have pre-built jax and jaxlib wheels.

PiperOrigin-RevId: 741249718
2025-03-27 12:54:47 -07:00
kaixih
f949b8b8f6 Enable public doc for scaled dot 2025-03-27 00:05:28 +00:00
Parker Schuh
6033592a95 Rename xla_extension_version to jaxlib_extension_version to reflect its new
scope.

PiperOrigin-RevId: 740944270
2025-03-26 16:36:34 -07:00
jax authors
79ece131dc Merge pull request #27404 from mattbahr:add-pascal-matrix
PiperOrigin-RevId: 740913011
2025-03-26 14:54:20 -07:00
Matt Bahr
81abbac536 add pascal matrix 2025-03-26 02:11:03 +00:00
jax authors
6144b371f7 Merge pull request #27355 from jlperla:identity
PiperOrigin-RevId: 740422195
2025-03-25 11:29:55 -07:00
jax authors
4da1faf5b6 Move PGLE documentation to JAX docs.
PiperOrigin-RevId: 739865595
2025-03-24 02:50:45 -07:00
Jesse Perla
5d79df7e67 Add identity activation
Fix typo
2025-03-23 15:13:17 -07:00
Arno Eigenwillig
7f0f185abd In JEP-12049, fix link to EAFP in the Python glossary:
the anchor became mixed-case as of Python 3.10.

PiperOrigin-RevId: 739150752
2025-03-21 05:56:43 -07:00
jax authors
4da751a97a Reverts e0c093314d8d9a6f68953f0c340c1b01d50ce386
PiperOrigin-RevId: 738662342
2025-03-19 21:51:26 -07:00
jax authors
e0c093314d Remove ; in code blocks of thinking_in_jax.md
PiperOrigin-RevId: 738656531
2025-03-19 21:25:10 -07:00
Yash Katariya
dde861af5f Remove the jax Array migration guide from the TOC tree but keep the doc around
PiperOrigin-RevId: 738421256
2025-03-19 09:05:45 -07:00
jax authors
4f70471310 Fix error in pallas tutorial
PiperOrigin-RevId: 737727935
2025-03-17 13:19:12 -07:00
Yash Katariya
3c0027af3b mixing modes 2025-03-14 18:23:27 -07:00
Jake VanderPlas
412b2e3acb Fix notebook formatting 2025-03-14 14:20:50 -07:00
Yash Katariya
aa9480a441 Expose get_abstract_mesh via the jax.sharding namespace
PiperOrigin-RevId: 736972976
2025-03-14 13:45:32 -07:00
Dougal
e8f43d1cef Explicit sharding docs 2025-03-14 16:33:30 -04:00
jax authors
bf829ff612 Merge pull request #26524 from carlosgmartin:random_multinomial
PiperOrigin-RevId: 736569564
2025-03-13 11:05:17 -07:00
Yash Katariya
2d01226b3b Rename some internal APIs (set_abstract_mesh -> use_abstract_mesh and set_concrete_mesh -> use_concrete_mesh)
PiperOrigin-RevId: 736382641
2025-03-12 22:30:05 -07:00
carlosgmartin
6b69a136aa Add jax.random.multinomial. 2025-03-12 18:15:14 -04:00
jax authors
ba367cdead Merge pull request #27044 from carlosgmartin:add_breadcrumbs_to_docs
PiperOrigin-RevId: 736283028
2025-03-12 15:14:17 -07:00
carlosgmartin
bc43b00d8f Add navigation breadcrumbs to docs. 2025-03-12 16:52:34 -04:00
Gunhyun Park
d191927b24 Fix syntax error and typos for composite primitive docstring.
PiperOrigin-RevId: 735808000
2025-03-11 10:37:07 -07:00
jax authors
b8590816bf Merge pull request #26839 from Sai-Suraj-27:fix_jax.debug.print
PiperOrigin-RevId: 735511953
2025-03-10 14:26:45 -07:00
Gary Miguel
6a718b762f
Update stateful-computations.md
tree_map -> tree.map
2025-03-09 21:35:46 -07:00
Jake Harmon
cdeeacabcf Update references to JAX's GitHub repo
JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax

PiperOrigin-RevId: 733536104
2025-03-04 18:31:09 -08:00
jax authors
1a57fdf704 Fix convolution example (kernel should be OIHW, not IOHW).
PiperOrigin-RevId: 732952185
2025-03-03 09:32:22 -08:00
Sai-Suraj-27
56285aec6b Fixed printing order of results in jax.debug.print documentation. 2025-02-28 06:36:10 +00:00
Tom Hennigan
1becb57ac9 Add jax.copy_to_host_async(tree).
A relatively common pattern I've observed is the following:

```python
_, metrics = some_jax_function()

with profiler.Trace('compute_metrics'):
  jax.block_until_ready(metrics)

with profiler.Trace('copy_to_host'):
  metrics = jax.device_get(metrics)
```

We are missing an opportunity here to more eagerly begin the h2d copy of
the metrics (e.g. overlap it with closing the "compute_metrics" context
manager etc. The intention of `jax.copy_to_host_async(x)` is to make it
simple to begin h2d transfers as early as possible. Adapting the above code:

```python
_, metrics = some_jax_function()

# Begin D2H copies as early as we can.
jax.copy_to_host_async(metrics)

with profiler.Trace('compute_metrics'):
  jax.block_until_ready(metrics)

with profiler.Trace('copy_to_host'):
  metrics = jax.device_get(metrics)
```

PiperOrigin-RevId: 731626446
2025-02-27 01:22:15 -08:00
Jake VanderPlas
f5ca46f5ec Sharp bits: add note on subnormal flush-to-zero 2025-02-24 18:52:38 -08:00
Peter Hawkins
85add667ed Change documentation to recommend libtpu from pypi instead of GCS. 2025-02-24 17:50:29 -05:00
George Necula
1be801bac8 [better_errors] Cleanup use of DebugInfo.arg_names and result_paths
Previously, we represented a missing arg name with `None`,
and a missing result path with the empty string. We now
adopt the same convention for arg names and use empty strings.
This simplifies the typing, and prevents the string "None" from
appearing in error messages.

I changed how we encode the result paths. Previously for a
function that returns a single array the path was the empty
string (the same as for an unknown path). And for a function
that returns a pair of arrays it was `([0], [1])`. Now we
add the "result" prefix: `("result",)` for a function returning a
single array and `(result[0], result[1])` for a function returning
a pair of arrays.

Finally, in debug_info_test, I removed the `check_tracer_arg_name`
so that all spied tracers are printed with the argument name they
depend on.
2025-02-23 08:27:56 +02:00
Jake VanderPlas
380fb0a5a0 CI: skip execution of convolutions.ipynb
This is failing on readthedocs with a dead kernel error
2025-02-20 13:14:59 -08:00
rajasekharporeddy
180be99798 Fix typos 2025-02-18 17:01:29 +05:30
Dan Foreman-Mackey
c6c38fb852 Reorder top-level functions in lax.linalg, and add/expand docstrings.
PiperOrigin-RevId: 726603731
2025-02-13 12:57:55 -08:00
Dougal
9145366f6f Part 1 of a new autodidax based on "stackless" 2025-02-12 15:23:06 -05:00
Roy Frostig
8720a9c0cd docstrings and API reference doc listing for the traced AOT stage 2025-02-11 22:30:50 -08:00
jax authors
914adaf60c Merge pull request #26476 from froystig:aot-doc-traced
PiperOrigin-RevId: 725902103
2025-02-11 22:01:21 -08:00
Roy Frostig
af381a73a3 update AOT walkthrough to cover explicit tracing stage 2025-02-11 21:26:05 -08:00
Jake VanderPlas
e389b707ba Add public APIs for jax.lax monoidal reductions 2025-02-11 16:00:03 -08:00
Jiyoun (Jen) Ha
e6f38b7fb3 [pallas] fix typo
PiperOrigin-RevId: 725244824
2025-02-10 09:32:07 -08:00