908 Commits

Author SHA1 Message Date
Yash Katariya
1ad3551ec9 Release jax and jaxlib 0.3.0 as per the new release process.
PiperOrigin-RevId: 427809845
2022-02-10 11:59:13 -08:00
Lena Martens
0042edb5f4 Checkify: rename some symbols and add some docstrings. 2022-02-08 17:40:04 +00:00
Peter Hawkins
8be057de1f Introduce a new jax/jaxlib versioning scheme.
Adds a design note that describes the scheme and how the jax and jaxlib versions
are related.
2022-02-07 17:59:42 -05:00
Peter Hawkins
465b593293 Update scipy intersphinx inventory for SciPy 1.8.0.
According to https://github.com/scipy/scipy/issues/14267 the SciPy docs seems to have moved.
2022-02-07 16:19:46 -05:00
Jake VanderPlas
ea8817b329 DOC: move experimental APIs to their own pages 2022-02-04 14:40:34 -08:00
Jake VanderPlas
fc10438b4f DOC: move functions in jax.html to their own pages 2022-02-04 14:40:34 -08:00
jax authors
248572c3e8 Merge pull request #9446 from hawkinsp:docsj
PiperOrigin-RevId: 426408109
2022-02-04 08:23:21 -08:00
jax authors
2a61403845 Merge pull request #9440 from hawkinsp:booktheme
PiperOrigin-RevId: 426407795
2022-02-04 08:18:12 -08:00
Peter Hawkins
efacc93088 Use the sphinx-book-theme for JAX documentation. 2022-02-04 09:00:07 -05:00
Peter Hawkins
a43c82a3b7 Add -j auto to the suggested doc build instructions.
Despite being marked as experimental, parallelism appears to work fine for JAX doc builds, and speeds up builds significantly.
2022-02-04 08:57:14 -05:00
jax authors
45d96c490e Merge pull request #4671 from romanngg:conv_local
PiperOrigin-RevId: 426282505
2022-02-03 18:03:33 -08:00
Peter Hawkins
df55ea5204 Move design notes into docs/, and render them as part of the documentation. 2022-02-02 14:29:03 -05:00
Jake VanderPlas
b9b79bab31 maint: update pre-commit package versions & fix new mypy errors 2022-01-31 13:39:11 -08:00
jax authors
d66daa9039 Merge pull request #9312 from GeraldCSC:jit-tutorial-update
PiperOrigin-RevId: 424875694
2022-01-28 08:27:39 -08:00
Jake VanderPlas
928087ada0 DOC: add info about repeated indices to jax.ops docs 2022-01-27 10:22:25 -08:00
Gerald Shen
4833fbb672 fix caching example in jit tutorial 2022-01-26 12:03:40 -05:00
Jake VanderPlas
6c38ec9a05 developer doc: more info on pre-commit 2022-01-25 10:36:03 -08:00
Jake VanderPlas
42a2e66fbe DOC: pin docutils==0.16 to restore bullets in lists 2022-01-24 12:02:39 -08:00
jax authors
7f91763303 Merge pull request #9287 from google:gda
PiperOrigin-RevId: 423866601
2022-01-24 11:26:37 -08:00
George Necula
83b818d45c Add more documentation for buffer donation
Fixes: #9237
2022-01-24 09:33:08 +01:00
yashkatariya
aa8cd77876 Make line length equal 2022-01-21 15:58:03 -08:00
yashkatariya
c0212d4079 Add GDA to the API pages 2022-01-21 14:37:41 -08:00
Matthew Johnson
726f60f6bc remove cpu platform setting in quickstart
fixes #9244
2022-01-19 10:24:09 -08:00
Lena Martens
f591d0b2e9
Add ensure_compile_time_eval docstring to docs 2022-01-14 11:18:40 +00:00
Mike McCoy
2e5ab11652 Resolves issue 8744 2022-01-12 21:10:45 +00:00
Matthew Johnson
1cf7d4ab5d Copybara import of the project:
--
4fcdadbfb3f4c484fd4432203cf13b88782b9311 by Matthew Johnson <mattjj@google.com>:

add jax.ensure_compile_time_eval to public api

aka jax.core.eval_context

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/7987 from google:issue7535 4fcdadbfb3f4c484fd4432203cf13b88782b9311
PiperOrigin-RevId: 420928687
2022-01-10 20:58:26 -08:00
jax authors
977e142d55 Merge pull request #9154 from che-shr-cat:patch-1
PiperOrigin-RevId: 420782381
2022-01-10 09:42:36 -08:00
che-shr-cat
d2c6c06546
Fix DeviceArray class reference 2022-01-10 17:34:09 +03:00
che-shr-cat
78977d6f5a fix broken links and update texts in thinking_in_jax.ipynb 2022-01-10 16:19:57 +03:00
Roman Novak
b9b759d4ff
Merge branch 'main' into conv_local 2022-01-07 09:51:46 -08:00
Jake VanderPlas
eba2ed2fd6 Update sphinx-related packages 2022-01-04 14:16:57 -08:00
jax authors
2e60850192 Merge pull request #9058 from che-shr-cat:main
PiperOrigin-RevId: 418917696
2021-12-30 01:39:40 -08:00
Grigory Sapunov
504728d8b6 link directly to the documentation for the jnp.ndarray.at property 2021-12-29 12:29:16 +03:00
Jake VanderPlas
b889282f6d docs: add FAQ section about jit compilation & numerics 2021-12-28 08:57:51 -08:00
Grigory Sapunov
f93531b020 replace deprecated jax.ops.index_* functions with the new index update operators 2021-12-27 20:23:29 +03:00
Vlad Feinberg
cd333f0f5b Fix straight-through estimator example in docs (#9032) 2021-12-21 22:25:12 +00:00
Jake VanderPlas
1f7d6316c2 doc: move stub section to bottom of FAQ 2021-12-15 16:19:14 -08:00
Matthew Johnson
0c68605bf1 add jax.block_until_ready to docs and changelog
also unrelatedly fix a couple of the uses of rst in changelog.md (though
many others remain)
2021-12-14 13:39:47 -08:00
jax authors
404c3c7d25 Merge pull request #8718 from jakevdp:config-doc
PiperOrigin-RevId: 413630185
2021-12-02 03:14:31 -08:00
jax authors
800aac8fd3 Merge pull request #8681 from jakevdp:numpy-faq
PiperOrigin-RevId: 413316336
2021-11-30 21:33:37 -08:00
Peter Hawkins
68e9e1c26d Consolidate more XLA-lowering logic between jit, pmap, and xmap.
Move remaining functions relating to building XLA HLO IR out of xla_bridge.py and into jax.interpreters.xla.

PiperOrigin-RevId: 413244450
2021-11-30 14:24:33 -08:00
Jake VanderPlas
0e4e30f4e5 DOC: add documentation for configuration functionality 2021-11-29 10:44:54 -08:00
Jake VanderPlas
4a72e57ce0 DOC: add FAQ section on JAX vs. Numpy performance 2021-11-24 12:04:02 -08:00
Matthew Johnson
8430deda3e custom pp_eqn rules, simpler xla_call print 2021-11-23 15:52:52 -08:00
Peter Hawkins
f3aa5fa92f Document lax.GatherScatterMode.
Recommend the .at[...] property in the docstrings for lax.scatter_ operators.

Add several missing lax.scatter_ operators to the index.
2021-11-22 15:43:02 -05:00
jax authors
f08a5a07a8 Merge pull request #8552 from mattjj:elide-more-convert-element-types
PiperOrigin-RevId: 411082070
2021-11-19 09:44:30 -08:00
Matthew Johnson
abbf78b5c3 generalize jaxpr simplification machinery
also:
* fix jit invariance bug around weak types
* elide trivial broadcasts

This started as an attempt to simplify some jaxpr pretty-prints, by (1)
eliding some convert_element_type applications that I thought were
unnecessary and (2) eliding some trivial broadcasts.

But it turned out that we were actually pruning more
convert_element_types than we should! In particular, see
test_weak_type_jit_invariance; that test fails on the main branch even
if we add the fixes in DynamicJaxprTrace.new_const, because [this
logic](b53a174042/jax/interpreters/partial_eval.py (L1225))
was not paying attention to weak types and hence clobbered them.

In addition to fixing those bugs that turned up (the changes in
DynamicJaxprTrace, and in what is now _convert_elt_type_fwd_rule), this
PR generalizes the jaxpr simplification machinery so as not to be a
couple special cases on convert_element_type_p. Insetad, we have tables
of rules! How we love them.

These rule signatures should let us add simplifications like forwarding
variables through calls and other higher-order primitives. That's all
future work though.
2021-11-19 09:00:59 -08:00
Peter Hawkins
58199b4b9a Delete the XLA in Python notebook.
Its tests are failing, and it describes a non-public API that we are phasing out.
2021-11-18 09:45:06 -05:00
Tianjian Lu
c5f73b3d8e [JAX] Added jax.lax.linalg.qdwh.
PiperOrigin-RevId: 406453671
2021-10-29 14:45:06 -07:00
Peter Hawkins
9ea55468ab [JAX] Update users of jax.ops.index... functions, which are deprecated.
* replace uses of `jax.ops.index[...]` with `jax.numpy.index_exp[...]`, which is a standard NumPy function that does the same thing.
* remove some redundant uses of `jax.ops.index[...]`, where the expression is passed directly to an indexed accessor function like `.at[...]`.
* update some remaining users of `jax.ops.index_update(x, jax.ops.index[idx], y)` to use the `x.at[idx].set(y)` APIs.

PiperOrigin-RevId: 406162068
2021-10-28 09:54:26 -07:00