1380 Commits

Author SHA1 Message Date
Jake VanderPlas
7bb8312f82 CI: update jupytext to v0.14.7 2023-07-24 11:51:45 -07:00
Jake VanderPlas
7d7a536b55 custom prng: introduce mechanism to identify key arrays by dtype 2023-07-21 12:27:32 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Roy Frostig
9150b239ff add jax.prng to uncovered modules list in API policy 2023-07-18 14:13:25 -07:00
Roy Frostig
9aa5307e2f API compatibility policy: expand on numerics and randomness 2023-07-18 14:13:25 -07:00
Peter Hawkins
f540ae4338 Fix warning about direct invocation of setup.py during jaxlib build.
The jaxlib wheel build currently uses `python setup.py bdist_wheel` to construct the wheel. Change it to use `python -m build -w` instead.

To avoid Python getting confused between the directory named `build` in the bazel tree and the Python `build` module, move `build_wheel.py` into `jaxlib/tools`.

PiperOrigin-RevId: 548133811
2023-07-14 08:31:16 -07:00
TJ
2432aa97a9 Change default memory allocation to 75% instead of 90% 2023-07-13 14:44:38 -07:00
Roy Frostig
1ad0a11897 AOT: better error messages on call signature mismatch
Also update error example in AOT docs.
2023-07-10 22:10:50 -07:00
Roy Frostig
14e38a3f9d AOT doc: fix lower/compile expression in error example 2023-07-10 18:27:06 -07:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
jax authors
f67acee129 Merge pull request #16430 from jakevdp:bool-error
PiperOrigin-RevId: 542951181
2023-06-23 14:00:12 -07:00
Peter Hawkins
bfa113ba60 Remove references to Python 3.8.
Remove the old build scripts/Dockerfile, since they are unused and broken.

PiperOrigin-RevId: 542870354
2023-06-23 08:48:57 -07:00
Jake VanderPlas
f1e603e4b3 errors: create TracerBoolConversionError for more targeted debugging tips 2023-06-21 01:41:45 -07:00
Peter Hawkins
34d9f5a9ae Add a CI presubmit that renders the documentation. 2023-06-20 09:29:25 -04:00
Peter Hawkins
5ce6748e2f Document how to enable concurrent kernel tracing on GPU. 2023-06-16 13:20:27 -04:00
Jake VanderPlas
d7a19442b6 DOC: fix formatting in FAQ 2023-06-14 03:17:08 -07:00
Tom Hennigan
ed073aa6c9 Add jax.tree_util.tree_leaves_with_path(tree).
PiperOrigin-RevId: 539609052
2023-06-12 04:13:37 -07:00
jax authors
8d27f20637 Merge pull request #16246 from chrisflesher:scipy-rotation-v3
PiperOrigin-RevId: 538788621
2023-06-08 08:10:58 -07:00
jax authors
6518e4e34c Merge pull request #16271 from jakevdp:abstract-array-deprecation
PiperOrigin-RevId: 538763490
2023-06-08 06:07:22 -07:00
Chris Flesher
5be17ed90c Added scipy.spatial.transform Rotation and Slerp classes 2023-06-08 07:51:32 -05:00
Jake VanderPlas
3fc70d3d8b Typo: remove stray tick in jax.numpy docs 2023-06-08 01:15:21 -07:00
Jake VanderPlas
47ae5bddd7 Mark jax.abstract_arrays as deprecated 2023-06-07 23:36:40 -07:00
Jake VanderPlas
3bef6214bb Deprecate jax.numpy functions alltrue, sometrue, product, cumproduct 2023-06-02 04:10:46 -07:00
jax authors
e99045381d Update mentioning of DeviceArray and ShardedDeviceArray to jax.Array in the parallelism tutorial
`jax.Array` is now a unified type for all kinds of arrays.

PiperOrigin-RevId: 537155869
2023-06-01 16:12:59 -07:00
jax authors
ae78de1a49 Merge pull request #16189 from skye:profiling_docs
PiperOrigin-RevId: 537046864
2023-06-01 09:34:46 -07:00
Jake VanderPlas
e5cd69479b DOC: fix doc formatting 2023-06-01 03:37:21 -07:00
ivyzheng
6bf1cbc667 Add key path related guide & code to the documentation. 2023-05-31 20:15:56 -07:00
Skye Wanderman-Milne
1d1429fe8b Update profiling docs.
* Mention that Tensorboard profiling supports device memory usage
* Recommend TB profiling instead of the pprof-based device memory profiling
* Minor updates to GCP instructions

Inspired by https://github.com/google/jax/issues/1491
2023-05-30 14:27:11 -07:00
Jake VanderPlas
333ff4abbc Add jnp.matrix_transpose() and jax.Array.mT
This is an API proposed by the Python Array API Standard (https://data-apis.org/array-api/2022.12/). It's lightweight enough that there's hardly any downside to supporting it in JAX.
2023-05-25 09:02:05 -07:00
Kevin Millikin
921fd222bf Refer to the original map/zip classes via builtins
Referring to them as simply `map` or `zip` will create recursive
reimplementations (with no base case!) if the cell is reevaluated in
the same runtime.
2023-05-24 07:47:50 +01:00
jax authors
85fb48a33c Merge pull request #15930 from canyon289:jax201
PiperOrigin-RevId: 534149169
2023-05-22 12:34:48 -07:00
Ravin Kumar
473fa7d670 Add building on JAX 2023-05-22 10:05:39 -07:00
jax authors
bb775c7ce1 Merge pull request #15871 from nouiz:doc
PiperOrigin-RevId: 533434343
2023-05-19 06:08:01 -07:00
Roy Frostig
ca008f37e3 initiate jax.extend via docs and top-level module set-up 2023-05-15 15:47:06 -07:00
Roy Frostig
ce840a9cd8 JEP: jax.extend, a module for extensions 2023-05-05 13:50:22 -07:00
Frederic Bastien
de57b4fd36 Fix a sphinx error. 2023-05-05 11:08:18 -07:00
Frederic Bastien
c1b532eda8 Remove one fct from doc per review. 2023-05-05 11:08:17 -07:00
Frederic Bastien
decdbfb166 Document jax.experimental.multihost_utils 2023-05-05 11:08:17 -07:00
jax authors
5d143e6eea Merge pull request #15818 from froystig:random-bits-direct
PiperOrigin-RevId: 529090390
2023-05-03 07:56:17 -07:00
Roy Frostig
ea3389205f add jax.random.bits 2023-05-03 06:10:05 -07:00
David Pizzuto
6948d32d15 contributing: Switch repo URL to HTTPS for consistency with other github URLs. 2023-05-01 10:03:39 -07:00
Jake VanderPlas
e059e3b52f DOC: document jax.experimental.sparse.linalg 2023-04-28 14:18:50 -07:00
Jake VanderPlas
8dc06ed2ce Document jax.lax.with_sharding_constraint 2023-04-26 10:19:04 -07:00
jax authors
70ebdb0502 Jax101 - Clarify that the compiled code is executed on first call
The current wording implies that the first time a jitted function
is called, the computation happens in Python. It's actually only
the tracing that happens in Python, and the compiled code is run
during the first call. The distinction is important e.g., to
understand why it might make sense to jit a function that's only
called once.

PiperOrigin-RevId: 526906176
2023-04-25 02:41:12 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
jax authors
975e76ef76 Merge pull request #15664 from skye:tpu_install
PiperOrigin-RevId: 525605301
2023-04-19 18:18:32 -07:00
jax authors
1de4d14da8 Merge pull request #15656 from laqua-stack:add-special-gamma-fcn
PiperOrigin-RevId: 525566749
2023-04-19 15:28:36 -07:00
Jake VanderPlas
a083ba7853 DOC: explicitly mention io_callback in FAQ 2023-04-19 12:30:53 -07:00
Skye Wanderman-Milne
b917a31f56 Update TPU install on main docs page 2023-04-19 17:52:16 +00:00
laqua-stack
d742733bea feat (scipy.special): Add a xla version of scipy.special.gamma function
- Add gamma fcn api in scipy.special
- Add tests for this purpose
- Add function to the docs

Currently, there is no implementation of the gamma function in jax
but there is one in scipy.special. This breaks some higher level
jit-compilation like in the blackjax backend for pymc. This commit
adds the missing gamma function.

Resolves: #15409
2023-04-18 21:10:22 +02:00