jax authors
4f70471310
Fix error in pallas tutorial
...
PiperOrigin-RevId: 737727935
2025-03-17 13:19:12 -07:00
Yash Katariya
3c0027af3b
mixing modes
2025-03-14 18:23:27 -07:00
Jake VanderPlas
412b2e3acb
Fix notebook formatting
2025-03-14 14:20:50 -07:00
Yash Katariya
aa9480a441
Expose get_abstract_mesh
via the jax.sharding
namespace
...
PiperOrigin-RevId: 736972976
2025-03-14 13:45:32 -07:00
Dougal
e8f43d1cef
Explicit sharding docs
2025-03-14 16:33:30 -04:00
jax authors
bf829ff612
Merge pull request #26524 from carlosgmartin:random_multinomial
...
PiperOrigin-RevId: 736569564
2025-03-13 11:05:17 -07:00
Yash Katariya
2d01226b3b
Rename some internal APIs (set_abstract_mesh -> use_abstract_mesh and set_concrete_mesh -> use_concrete_mesh)
...
PiperOrigin-RevId: 736382641
2025-03-12 22:30:05 -07:00
carlosgmartin
6b69a136aa
Add jax.random.multinomial.
2025-03-12 18:15:14 -04:00
jax authors
ba367cdead
Merge pull request #27044 from carlosgmartin:add_breadcrumbs_to_docs
...
PiperOrigin-RevId: 736283028
2025-03-12 15:14:17 -07:00
carlosgmartin
bc43b00d8f
Add navigation breadcrumbs to docs.
2025-03-12 16:52:34 -04:00
Gunhyun Park
d191927b24
Fix syntax error and typos for composite primitive docstring.
...
PiperOrigin-RevId: 735808000
2025-03-11 10:37:07 -07:00
jax authors
b8590816bf
Merge pull request #26839 from Sai-Suraj-27:fix_jax.debug.print
...
PiperOrigin-RevId: 735511953
2025-03-10 14:26:45 -07:00
Gary Miguel
6a718b762f
Update stateful-computations.md
...
tree_map -> tree.map
2025-03-09 21:35:46 -07:00
Jake Harmon
cdeeacabcf
Update references to JAX's GitHub repo
...
JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax
PiperOrigin-RevId: 733536104
2025-03-04 18:31:09 -08:00
jax authors
1a57fdf704
Fix convolution example (kernel should be OIHW, not IOHW).
...
PiperOrigin-RevId: 732952185
2025-03-03 09:32:22 -08:00
Sai-Suraj-27
56285aec6b
Fixed printing order of results in jax.debug.print documentation.
2025-02-28 06:36:10 +00:00
Tom Hennigan
1becb57ac9
Add jax.copy_to_host_async(tree)
.
...
A relatively common pattern I've observed is the following:
```python
_, metrics = some_jax_function()
with profiler.Trace('compute_metrics'):
jax.block_until_ready(metrics)
with profiler.Trace('copy_to_host'):
metrics = jax.device_get(metrics)
```
We are missing an opportunity here to more eagerly begin the h2d copy of
the metrics (e.g. overlap it with closing the "compute_metrics" context
manager etc. The intention of `jax.copy_to_host_async(x)` is to make it
simple to begin h2d transfers as early as possible. Adapting the above code:
```python
_, metrics = some_jax_function()
# Begin D2H copies as early as we can.
jax.copy_to_host_async(metrics)
with profiler.Trace('compute_metrics'):
jax.block_until_ready(metrics)
with profiler.Trace('copy_to_host'):
metrics = jax.device_get(metrics)
```
PiperOrigin-RevId: 731626446
2025-02-27 01:22:15 -08:00
Jake VanderPlas
f5ca46f5ec
Sharp bits: add note on subnormal flush-to-zero
2025-02-24 18:52:38 -08:00
Peter Hawkins
85add667ed
Change documentation to recommend libtpu from pypi instead of GCS.
2025-02-24 17:50:29 -05:00
George Necula
1be801bac8
[better_errors] Cleanup use of DebugInfo.arg_names and result_paths
...
Previously, we represented a missing arg name with `None`,
and a missing result path with the empty string. We now
adopt the same convention for arg names and use empty strings.
This simplifies the typing, and prevents the string "None" from
appearing in error messages.
I changed how we encode the result paths. Previously for a
function that returns a single array the path was the empty
string (the same as for an unknown path). And for a function
that returns a pair of arrays it was `([0], [1])`. Now we
add the "result" prefix: `("result",)` for a function returning a
single array and `(result[0], result[1])` for a function returning
a pair of arrays.
Finally, in debug_info_test, I removed the `check_tracer_arg_name`
so that all spied tracers are printed with the argument name they
depend on.
2025-02-23 08:27:56 +02:00
Jake VanderPlas
380fb0a5a0
CI: skip execution of convolutions.ipynb
...
This is failing on readthedocs with a dead kernel error
2025-02-20 13:14:59 -08:00
rajasekharporeddy
180be99798
Fix typos
2025-02-18 17:01:29 +05:30
Dan Foreman-Mackey
c6c38fb852
Reorder top-level functions in lax.linalg, and add/expand docstrings.
...
PiperOrigin-RevId: 726603731
2025-02-13 12:57:55 -08:00
Dougal
9145366f6f
Part 1 of a new autodidax based on "stackless"
2025-02-12 15:23:06 -05:00
Roy Frostig
8720a9c0cd
docstrings and API reference doc listing for the traced AOT stage
2025-02-11 22:30:50 -08:00
jax authors
914adaf60c
Merge pull request #26476 from froystig:aot-doc-traced
...
PiperOrigin-RevId: 725902103
2025-02-11 22:01:21 -08:00
Roy Frostig
af381a73a3
update AOT walkthrough to cover explicit tracing stage
2025-02-11 21:26:05 -08:00
Jake VanderPlas
e389b707ba
Add public APIs for jax.lax monoidal reductions
2025-02-11 16:00:03 -08:00
Jiyoun (Jen) Ha
e6f38b7fb3
[pallas] fix typo
...
PiperOrigin-RevId: 725244824
2025-02-10 09:32:07 -08:00
Dan Foreman-Mackey
21281f377e
Remove the tutorial about legacy custom calls.
2025-02-06 21:04:30 -05:00
Jake VanderPlas
e0ee415677
DOC: update numpy/scipy versions in deprecation doc
2025-02-06 15:18:11 -08:00
Jake VanderPlas
2fb750e0ab
doc: improve docs for jax.lax trig functions
2025-02-06 11:09:55 -08:00
Michael Hudgins
2e808f2836
Merge pull request #26279 from MichaelHudgins:tsan-resultstore
...
PiperOrigin-RevId: 723918760
2025-02-06 14:55:57 +00:00
Qazalbash
7fc605f783
Merge branch 'main' into scipy-expon
2025-02-05 23:33:51 +05:00
Jake VanderPlas
e4dac395a5
Roll back multinomial change from https://github.com/jax-ml/jax/pull/25688
...
This has test breakages on TPU: https://github.com/jax-ml/jax/actions/runs/13159081976/job/36723019653
Reverts 95535df13b422284043623ca3a6d2a5962116fb1
PiperOrigin-RevId: 723536107
2025-02-05 09:13:56 -08:00
jax authors
6281b86008
Merge pull request #26289 from reikdas:reikdas/docfix
...
PiperOrigin-RevId: 722878611
2025-02-03 18:29:41 -08:00
Pratyush Das
72b6704d88
documentation: fix function name
2025-02-03 19:09:00 -05:00
jax authors
95535df13b
Merge pull request #25688 from carlosgmartin:random_multinomial
...
PiperOrigin-RevId: 722741835
2025-02-03 11:52:43 -08:00
jax authors
12c76cdeaa
Merge pull request #25474 from jaro-sevcik:compilation-cache-mock-doc
...
PiperOrigin-RevId: 722681339
2025-02-03 09:07:31 -08:00
Qazalbash
46c5b865b7
docs: Update jax.scipy.rst
to include exponential distribution functions in documentation
2025-02-01 12:55:49 +05:00
carlosgmartin
32411a430f
Add jax.random.multinomial.
2025-01-31 18:45:55 -05:00
Axel Donath
dd7c203b81
Adjust conf.py to remove prompt from sphinx copybutton
2025-01-31 13:51:27 -05:00
Dan Foreman-Mackey
e2eff1f8d5
Revert https://github.com/jax-ml/jax/pull/25982 since callbacks can now use JAX functions.
2025-01-29 11:12:32 -05:00
jax authors
bf22b53cf4
Merge pull request #26154 from jakevdp:pure-callback-doc
...
PiperOrigin-RevId: 720763192
2025-01-28 17:28:02 -08:00
Jake VanderPlas
25aa5a3008
DOC: avoid deprecated argument in external callbacks
2025-01-28 11:34:56 -08:00
Jake VanderPlas
ba2858f834
DOC: add discussion of exceptions in pure_callback
2025-01-28 09:53:47 -08:00
jax authors
6004a501ad
Merge pull request #26129 from jakevdp:osx-support
...
PiperOrigin-RevId: 720246018
2025-01-27 11:32:37 -08:00
Jake VanderPlas
7eacce0b83
DOC: update osx x86 entry in installation grid
2025-01-27 11:18:33 -08:00
Dan Foreman-Mackey
782138fb6f
Add custom_dce to changelogs and API docs.
2025-01-27 13:03:34 -05:00
jax authors
726abc9c31
Merge pull request #26082 from Saransh-cpp:index-update-syntax-link
...
PiperOrigin-RevId: 719356167
2025-01-24 10:37:07 -08:00