2140 Commits

Author SHA1 Message Date
jax authors
4f70471310 Fix error in pallas tutorial
PiperOrigin-RevId: 737727935
2025-03-17 13:19:12 -07:00
Yash Katariya
3c0027af3b mixing modes 2025-03-14 18:23:27 -07:00
Jake VanderPlas
412b2e3acb Fix notebook formatting 2025-03-14 14:20:50 -07:00
Yash Katariya
aa9480a441 Expose get_abstract_mesh via the jax.sharding namespace
PiperOrigin-RevId: 736972976
2025-03-14 13:45:32 -07:00
Dougal
e8f43d1cef Explicit sharding docs 2025-03-14 16:33:30 -04:00
jax authors
bf829ff612 Merge pull request #26524 from carlosgmartin:random_multinomial
PiperOrigin-RevId: 736569564
2025-03-13 11:05:17 -07:00
Yash Katariya
2d01226b3b Rename some internal APIs (set_abstract_mesh -> use_abstract_mesh and set_concrete_mesh -> use_concrete_mesh)
PiperOrigin-RevId: 736382641
2025-03-12 22:30:05 -07:00
carlosgmartin
6b69a136aa Add jax.random.multinomial. 2025-03-12 18:15:14 -04:00
jax authors
ba367cdead Merge pull request #27044 from carlosgmartin:add_breadcrumbs_to_docs
PiperOrigin-RevId: 736283028
2025-03-12 15:14:17 -07:00
carlosgmartin
bc43b00d8f Add navigation breadcrumbs to docs. 2025-03-12 16:52:34 -04:00
Gunhyun Park
d191927b24 Fix syntax error and typos for composite primitive docstring.
PiperOrigin-RevId: 735808000
2025-03-11 10:37:07 -07:00
jax authors
b8590816bf Merge pull request #26839 from Sai-Suraj-27:fix_jax.debug.print
PiperOrigin-RevId: 735511953
2025-03-10 14:26:45 -07:00
Gary Miguel
6a718b762f
Update stateful-computations.md
tree_map -> tree.map
2025-03-09 21:35:46 -07:00
Jake Harmon
cdeeacabcf Update references to JAX's GitHub repo
JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax

PiperOrigin-RevId: 733536104
2025-03-04 18:31:09 -08:00
jax authors
1a57fdf704 Fix convolution example (kernel should be OIHW, not IOHW).
PiperOrigin-RevId: 732952185
2025-03-03 09:32:22 -08:00
Sai-Suraj-27
56285aec6b Fixed printing order of results in jax.debug.print documentation. 2025-02-28 06:36:10 +00:00
Tom Hennigan
1becb57ac9 Add jax.copy_to_host_async(tree).
A relatively common pattern I've observed is the following:

```python
_, metrics = some_jax_function()

with profiler.Trace('compute_metrics'):
  jax.block_until_ready(metrics)

with profiler.Trace('copy_to_host'):
  metrics = jax.device_get(metrics)
```

We are missing an opportunity here to more eagerly begin the h2d copy of
the metrics (e.g. overlap it with closing the "compute_metrics" context
manager etc. The intention of `jax.copy_to_host_async(x)` is to make it
simple to begin h2d transfers as early as possible. Adapting the above code:

```python
_, metrics = some_jax_function()

# Begin D2H copies as early as we can.
jax.copy_to_host_async(metrics)

with profiler.Trace('compute_metrics'):
  jax.block_until_ready(metrics)

with profiler.Trace('copy_to_host'):
  metrics = jax.device_get(metrics)
```

PiperOrigin-RevId: 731626446
2025-02-27 01:22:15 -08:00
Jake VanderPlas
f5ca46f5ec Sharp bits: add note on subnormal flush-to-zero 2025-02-24 18:52:38 -08:00
Peter Hawkins
85add667ed Change documentation to recommend libtpu from pypi instead of GCS. 2025-02-24 17:50:29 -05:00
George Necula
1be801bac8 [better_errors] Cleanup use of DebugInfo.arg_names and result_paths
Previously, we represented a missing arg name with `None`,
and a missing result path with the empty string. We now
adopt the same convention for arg names and use empty strings.
This simplifies the typing, and prevents the string "None" from
appearing in error messages.

I changed how we encode the result paths. Previously for a
function that returns a single array the path was the empty
string (the same as for an unknown path). And for a function
that returns a pair of arrays it was `([0], [1])`. Now we
add the "result" prefix: `("result",)` for a function returning a
single array and `(result[0], result[1])` for a function returning
a pair of arrays.

Finally, in debug_info_test, I removed the `check_tracer_arg_name`
so that all spied tracers are printed with the argument name they
depend on.
2025-02-23 08:27:56 +02:00
Jake VanderPlas
380fb0a5a0 CI: skip execution of convolutions.ipynb
This is failing on readthedocs with a dead kernel error
2025-02-20 13:14:59 -08:00
rajasekharporeddy
180be99798 Fix typos 2025-02-18 17:01:29 +05:30
Dan Foreman-Mackey
c6c38fb852 Reorder top-level functions in lax.linalg, and add/expand docstrings.
PiperOrigin-RevId: 726603731
2025-02-13 12:57:55 -08:00
Dougal
9145366f6f Part 1 of a new autodidax based on "stackless" 2025-02-12 15:23:06 -05:00
Roy Frostig
8720a9c0cd docstrings and API reference doc listing for the traced AOT stage 2025-02-11 22:30:50 -08:00
jax authors
914adaf60c Merge pull request #26476 from froystig:aot-doc-traced
PiperOrigin-RevId: 725902103
2025-02-11 22:01:21 -08:00
Roy Frostig
af381a73a3 update AOT walkthrough to cover explicit tracing stage 2025-02-11 21:26:05 -08:00
Jake VanderPlas
e389b707ba Add public APIs for jax.lax monoidal reductions 2025-02-11 16:00:03 -08:00
Jiyoun (Jen) Ha
e6f38b7fb3 [pallas] fix typo
PiperOrigin-RevId: 725244824
2025-02-10 09:32:07 -08:00
Dan Foreman-Mackey
21281f377e Remove the tutorial about legacy custom calls. 2025-02-06 21:04:30 -05:00
Jake VanderPlas
e0ee415677 DOC: update numpy/scipy versions in deprecation doc 2025-02-06 15:18:11 -08:00
Jake VanderPlas
2fb750e0ab doc: improve docs for jax.lax trig functions 2025-02-06 11:09:55 -08:00
Michael Hudgins
2e808f2836 Merge pull request #26279 from MichaelHudgins:tsan-resultstore
PiperOrigin-RevId: 723918760
2025-02-06 14:55:57 +00:00
Qazalbash
7fc605f783
Merge branch 'main' into scipy-expon 2025-02-05 23:33:51 +05:00
Jake VanderPlas
e4dac395a5 Roll back multinomial change from https://github.com/jax-ml/jax/pull/25688
This has test breakages on TPU: https://github.com/jax-ml/jax/actions/runs/13159081976/job/36723019653

Reverts 95535df13b422284043623ca3a6d2a5962116fb1

PiperOrigin-RevId: 723536107
2025-02-05 09:13:56 -08:00
jax authors
6281b86008 Merge pull request #26289 from reikdas:reikdas/docfix
PiperOrigin-RevId: 722878611
2025-02-03 18:29:41 -08:00
Pratyush Das
72b6704d88 documentation: fix function name 2025-02-03 19:09:00 -05:00
jax authors
95535df13b Merge pull request #25688 from carlosgmartin:random_multinomial
PiperOrigin-RevId: 722741835
2025-02-03 11:52:43 -08:00
jax authors
12c76cdeaa Merge pull request #25474 from jaro-sevcik:compilation-cache-mock-doc
PiperOrigin-RevId: 722681339
2025-02-03 09:07:31 -08:00
Qazalbash
46c5b865b7
docs: Update jax.scipy.rst to include exponential distribution functions in documentation 2025-02-01 12:55:49 +05:00
carlosgmartin
32411a430f Add jax.random.multinomial. 2025-01-31 18:45:55 -05:00
Axel Donath
dd7c203b81 Adjust conf.py to remove prompt from sphinx copybutton 2025-01-31 13:51:27 -05:00
Dan Foreman-Mackey
e2eff1f8d5 Revert https://github.com/jax-ml/jax/pull/25982 since callbacks can now use JAX functions. 2025-01-29 11:12:32 -05:00
jax authors
bf22b53cf4 Merge pull request #26154 from jakevdp:pure-callback-doc
PiperOrigin-RevId: 720763192
2025-01-28 17:28:02 -08:00
Jake VanderPlas
25aa5a3008 DOC: avoid deprecated argument in external callbacks 2025-01-28 11:34:56 -08:00
Jake VanderPlas
ba2858f834 DOC: add discussion of exceptions in pure_callback 2025-01-28 09:53:47 -08:00
jax authors
6004a501ad Merge pull request #26129 from jakevdp:osx-support
PiperOrigin-RevId: 720246018
2025-01-27 11:32:37 -08:00
Jake VanderPlas
7eacce0b83 DOC: update osx x86 entry in installation grid 2025-01-27 11:18:33 -08:00
Dan Foreman-Mackey
782138fb6f Add custom_dce to changelogs and API docs. 2025-01-27 13:03:34 -05:00
jax authors
726abc9c31 Merge pull request #26082 from Saransh-cpp:index-update-syntax-link
PiperOrigin-RevId: 719356167
2025-01-24 10:37:07 -08:00