41 Commits

Author SHA1 Message Date
Jake VanderPlas
067be89a0c DOC: minor documentation & formatting fixes 2021-02-23 10:31:44 -08:00
Tom Hennigan
7adb1e381d Add jax.default_backend() which returns the default platform name.
This can be useful when you need backend specific behaviour, e.g.:

    if jax.default_backend() == 'gpu':
      dataset = double_buffer(dataset)

Or if you want to assert a given backend is the default:

    assert jax.default_backend() == 'tpu'

I am a bit conflicted by the naming, "backend" is consistent with other APIs in
JAX (e.g. jit, local_devices etc) which accept a "backend" string which is used
to lookup an XLA backend by platform name.
2021-02-04 14:50:15 +00:00
Roy Frostig
4adc4362ef include closure_convert in generated docs 2021-01-25 17:42:46 -08:00
Lena Martens
d7743140a3
Add named_call autofunction description to docs 2020-11-30 14:23:52 +01:00
Qiumin Xu
31600aac62 Add named_call public API.
Move named_call_p to core.py from lax.py.
Also move the translation rule to jax/interpreters/xla.py where the core_call translation rule is.
2020-11-12 17:32:01 -08:00
jax authors
ffff3a42fc Merge pull request #4828 from j-towns:api-doc-fixes
PiperOrigin-RevId: 341395571
2020-11-09 06:46:48 -08:00
Jamie Townsend
a0a2c973e6 Fixes to jax.api docs 2020-11-08 22:22:36 +00:00
Jamie Townsend
931b2ddbcb Improve docs for custom_jvp and custom_vjp
Correct the custom_jvp docstring to include the defjvps instance method. Add the
defjvp/defvjp instance methods to the sphinx doc.
2020-11-06 11:40:16 +00:00
Stephan Hoyer
877053d8ab
Add jax.linear_transpose (#3398)
* Add jax.linear_transpose

Co-authored-by: Matthew Johnson <mattjj@google.com>

* add failing test for complex numbers

* Add picky dtype check for linear_transpose

* Lint fix

* Allow truncating dtypes to match inputs in linear_transpose

* Fix typo in shape check error

* improve docstring

* Don't support integer inputs; better docstring

* fixup

* Fix doctest

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-09-16 20:29:19 -07:00
Stephan Hoyer
816bcd7196
Fix jax.checkpoint in API docs (#3980)
On the index API doc page, it turns out functions need to be listed *twice* to
appear.
2020-08-06 09:39:33 -07:00
Peter Hawkins
b943b31b22
Add jax.image.resize. (#3703)
* Add jax.image.resize.

This is a port of `tf.image.resize()` and the `ScaleAndTranslate` operator.

While I don't expect this implementation to be particularly fast, it is a useful generic implementation to which we can add optimized special cases as the need arises.
2020-07-10 09:57:59 -04:00
Matthew Johnson
93054fa1d3
add remat to top-level docs (#3554) 2020-06-25 07:26:26 -07:00
Matthew Johnson
ae9df752de add docstring to ravel_pytree 2020-06-12 15:41:07 -07:00
Stephan Hoyer
5a0bf46234
DOC: add a table of contents for top level API docs (#2946)
This makes them easier to scan.
2020-05-04 12:37:29 -07:00
Matthew Johnson
83bf048f8a separate out deprecated custom_transforms stuff 2020-03-23 16:45:28 -07:00
Matthew Johnson
7e480fa923 add custom_jvp / vjp, delete custom_transforms 2020-03-21 22:08:03 -07:00
Skye Wanderman-Milne
a1fa6296cc
Document jax.device_put (#2366) 2020-03-05 14:45:01 -08:00
Peter Hawkins
ddd76b803b
Add a jax.profiler module that exposes the new TensorFlow profiler in… (#2354)
* Add a jax.profiler module that exposes the new TensorFlow profiler integration.

* Fix documentation.
2020-03-05 12:07:57 -05:00
Stephan Hoyer
48f2a41453
Minor fixes to docs related to jax.numpy.vectorize (#2278)
- Show `numpy.jax.vectorize` explicitly in the JAX docs, rather than the
  original `numpy.vectorize.
- Updated regex for identifying function signatures in NumPy. This now correctly
  parses `np.vectorize` and `np.einsum`.
- Removed docs for `jax.experimental.vectorize`. There's still some good
  narrative content in the docstring but it should go somewhere else.
2020-02-23 19:10:39 +01:00
George Necula
938336e08a
Merge pull request #2216 from gnecula/documentation
Added the first draft of the Jaxpr documentation.
2020-02-14 07:23:47 +01:00
Stephan Hoyer
00140f07e2
Add jax.numpy.vectorize (#2146)
* Add jax.numpy.vectorize

This is basically a non-experimental version of the machinery in
`jax.experimental.vectorize`, except:
- It adds the `excluded` argument from NumPy, which works just like
  `static_argnums` in `jax.jit`.
- It doesn't include the `axis` argument yet (which NumPy doesn't have).

Eventually we might want want to consolidate the specification of signatures
with signatures used by shape-checking machinery, but it's nice to emulate
NumPy's existing interface, and this is already useful (e.g., for writing
vectorized linear algebra routines).

* Add deprecation warning to jax.experimental.vectorize

* improve implementation
2020-02-12 14:09:37 -08:00
George Necula
a5c3468c93 Added the first draft of the Jaxpr documentation.
This replaces the previous Google Doc version, and is now
updated with the latest changes in Jaxpr.
2020-02-12 13:01:43 +01:00
Peter Hawkins
91cd20b173
Update documentation and changelog to mention DLPack and array interface support. (#2134) 2020-01-31 11:15:04 -05:00
Skye Wanderman-Milne
6f3cb1c3ee Add jax.devices(), etc. to the docs. 2019-11-22 11:03:42 -08:00
James Bradbury
cc49d8b325 add docs for jax.nn 2019-08-29 18:15:36 -07:00
Stephan Hoyer
ee54f6553c Fix index of API docs
The current index of the API docs page seems to have broken links: when I
click on "Automatic differentiation" for example, I get sent to the "JIT"
section.

This change fixes the links.
2019-08-20 23:13:15 -07:00
Peter Hawkins
4eb1820ae2 Add documentation to JAX modules. 2019-07-21 15:55:47 -04:00
Matthew Johnson
0f1d913db3 add xla_computation to jax rst docs 2019-07-05 16:55:11 -07:00
Matthew Johnson
bbf625e0df fix jax.rst docs (remove defvjp2 / defjvp2) 2019-06-11 06:52:55 -07:00
Matthew Johnson
ab20f0292c add docstring for defjvp_all 2019-06-05 17:34:14 -07:00
Matthew Johnson
720dec4072 add custom_gradient 2019-06-05 13:48:04 -07:00
Matthew Johnson
ffec059f0e add jax.eval_shape to reference docs via jax.rst 2019-06-01 09:53:32 -07:00
Matthew Johnson
d199471696 add pmap to sphinx docs
Co-authored-by: Peter Hawkins <phawkins@google.com>
2019-05-15 08:25:20 -07:00
Peter Hawkins
8c2c1a2e71 Expose jax.tree_util in the JAX docs. 2019-05-14 21:00:27 -04:00
Matthew Johnson
f5b4391f38 add jax.linearize to jax.readthedocs.io 2019-03-25 11:27:29 -07:00
Peter Hawkins
8b5e09f10a Add new functions jax.ops.index_add and jax.ops.index_update for NumPy-style indexed updates.
Create a new library `jax.ops` for user-facing ops that don't exist in NumPy or SciPy.

Progress on issue #101. Fixes #122.

Reenable some disabled TPU indexing tests that now pass.
2019-03-04 15:13:14 -05:00
Peter Hawkins
67174e3d57 Document jax.disable_jit. Add an example to jax.grad. 2019-02-20 09:00:12 -05:00
Peter Hawkins
b714cb30cc Add documentation for jax.jvp and jax.vjp. 2019-02-19 22:08:14 -05:00
Peter Hawkins
15e6c27130 Improve JAX API docs.
Add examples for `jax.jit`, `jax.jacfwd`, and `jax.jacrev`.
Document `jax.hessian`. Add `argnums` support to `jax.hessian`.
2019-02-15 08:16:25 -05:00
Matthew Johnson
9cf24029b2 add jax.random entry to jax.rst 2019-02-13 19:31:41 -08:00
Peter Hawkins
86d8915c3d Add Sphinx-generated reference documentation for JAX. 2019-01-16 09:13:31 -05:00