Jake VanderPlas
7bb8312f82
CI: update jupytext to v0.14.7
2023-07-24 11:51:45 -07:00
Jake VanderPlas
7d7a536b55
custom prng: introduce mechanism to identify key arrays by dtype
2023-07-21 12:27:32 -07:00
Peter Hawkins
319ab98980
Apply pyupgrade --py39-plus.
...
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Roy Frostig
9150b239ff
add jax.prng
to uncovered modules list in API policy
2023-07-18 14:13:25 -07:00
Roy Frostig
9aa5307e2f
API compatibility policy: expand on numerics and randomness
2023-07-18 14:13:25 -07:00
Peter Hawkins
f540ae4338
Fix warning about direct invocation of setup.py during jaxlib build.
...
The jaxlib wheel build currently uses `python setup.py bdist_wheel` to construct the wheel. Change it to use `python -m build -w` instead.
To avoid Python getting confused between the directory named `build` in the bazel tree and the Python `build` module, move `build_wheel.py` into `jaxlib/tools`.
PiperOrigin-RevId: 548133811
2023-07-14 08:31:16 -07:00
TJ
2432aa97a9
Change default memory allocation to 75% instead of 90%
2023-07-13 14:44:38 -07:00
Roy Frostig
1ad0a11897
AOT: better error messages on call signature mismatch
...
Also update error example in AOT docs.
2023-07-10 22:10:50 -07:00
Roy Frostig
14e38a3f9d
AOT doc: fix lower
/compile
expression in error example
2023-07-10 18:27:06 -07:00
Peter Hawkins
816ba91263
Use lower-case PEP 585 names for types.
...
Issue https://github.com/google/jax/issues/16537
PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
jax authors
f67acee129
Merge pull request #16430 from jakevdp:bool-error
...
PiperOrigin-RevId: 542951181
2023-06-23 14:00:12 -07:00
Peter Hawkins
bfa113ba60
Remove references to Python 3.8.
...
Remove the old build scripts/Dockerfile, since they are unused and broken.
PiperOrigin-RevId: 542870354
2023-06-23 08:48:57 -07:00
Jake VanderPlas
f1e603e4b3
errors: create TracerBoolConversionError for more targeted debugging tips
2023-06-21 01:41:45 -07:00
Peter Hawkins
34d9f5a9ae
Add a CI presubmit that renders the documentation.
2023-06-20 09:29:25 -04:00
Peter Hawkins
5ce6748e2f
Document how to enable concurrent kernel tracing on GPU.
2023-06-16 13:20:27 -04:00
Jake VanderPlas
d7a19442b6
DOC: fix formatting in FAQ
2023-06-14 03:17:08 -07:00
Tom Hennigan
ed073aa6c9
Add jax.tree_util.tree_leaves_with_path(tree).
...
PiperOrigin-RevId: 539609052
2023-06-12 04:13:37 -07:00
jax authors
8d27f20637
Merge pull request #16246 from chrisflesher:scipy-rotation-v3
...
PiperOrigin-RevId: 538788621
2023-06-08 08:10:58 -07:00
jax authors
6518e4e34c
Merge pull request #16271 from jakevdp:abstract-array-deprecation
...
PiperOrigin-RevId: 538763490
2023-06-08 06:07:22 -07:00
Chris Flesher
5be17ed90c
Added scipy.spatial.transform Rotation and Slerp classes
2023-06-08 07:51:32 -05:00
Jake VanderPlas
3fc70d3d8b
Typo: remove stray tick in jax.numpy docs
2023-06-08 01:15:21 -07:00
Jake VanderPlas
47ae5bddd7
Mark jax.abstract_arrays as deprecated
2023-06-07 23:36:40 -07:00
Jake VanderPlas
3bef6214bb
Deprecate jax.numpy functions alltrue, sometrue, product, cumproduct
2023-06-02 04:10:46 -07:00
jax authors
e99045381d
Update mentioning of DeviceArray
and ShardedDeviceArray
to jax.Array
in the parallelism tutorial
...
`jax.Array` is now a unified type for all kinds of arrays.
PiperOrigin-RevId: 537155869
2023-06-01 16:12:59 -07:00
jax authors
ae78de1a49
Merge pull request #16189 from skye:profiling_docs
...
PiperOrigin-RevId: 537046864
2023-06-01 09:34:46 -07:00
Jake VanderPlas
e5cd69479b
DOC: fix doc formatting
2023-06-01 03:37:21 -07:00
ivyzheng
6bf1cbc667
Add key path related guide & code to the documentation.
2023-05-31 20:15:56 -07:00
Skye Wanderman-Milne
1d1429fe8b
Update profiling docs.
...
* Mention that Tensorboard profiling supports device memory usage
* Recommend TB profiling instead of the pprof-based device memory profiling
* Minor updates to GCP instructions
Inspired by https://github.com/google/jax/issues/1491
2023-05-30 14:27:11 -07:00
Jake VanderPlas
333ff4abbc
Add jnp.matrix_transpose() and jax.Array.mT
...
This is an API proposed by the Python Array API Standard (https://data-apis.org/array-api/2022.12/ ). It's lightweight enough that there's hardly any downside to supporting it in JAX.
2023-05-25 09:02:05 -07:00
Kevin Millikin
921fd222bf
Refer to the original map
/zip
classes via builtins
...
Referring to them as simply `map` or `zip` will create recursive
reimplementations (with no base case!) if the cell is reevaluated in
the same runtime.
2023-05-24 07:47:50 +01:00
jax authors
85fb48a33c
Merge pull request #15930 from canyon289:jax201
...
PiperOrigin-RevId: 534149169
2023-05-22 12:34:48 -07:00
Ravin Kumar
473fa7d670
Add building on JAX
2023-05-22 10:05:39 -07:00
jax authors
bb775c7ce1
Merge pull request #15871 from nouiz:doc
...
PiperOrigin-RevId: 533434343
2023-05-19 06:08:01 -07:00
Roy Frostig
ca008f37e3
initiate jax.extend
via docs and top-level module set-up
2023-05-15 15:47:06 -07:00
Roy Frostig
ce840a9cd8
JEP: jax.extend
, a module for extensions
2023-05-05 13:50:22 -07:00
Frederic Bastien
de57b4fd36
Fix a sphinx error.
2023-05-05 11:08:18 -07:00
Frederic Bastien
c1b532eda8
Remove one fct from doc per review.
2023-05-05 11:08:17 -07:00
Frederic Bastien
decdbfb166
Document jax.experimental.multihost_utils
2023-05-05 11:08:17 -07:00
jax authors
5d143e6eea
Merge pull request #15818 from froystig:random-bits-direct
...
PiperOrigin-RevId: 529090390
2023-05-03 07:56:17 -07:00
Roy Frostig
ea3389205f
add jax.random.bits
2023-05-03 06:10:05 -07:00
David Pizzuto
6948d32d15
contributing: Switch repo URL to HTTPS for consistency with other github URLs.
2023-05-01 10:03:39 -07:00
Jake VanderPlas
e059e3b52f
DOC: document jax.experimental.sparse.linalg
2023-04-28 14:18:50 -07:00
Jake VanderPlas
8dc06ed2ce
Document jax.lax.with_sharding_constraint
2023-04-26 10:19:04 -07:00
jax authors
70ebdb0502
Jax101 - Clarify that the compiled code is executed on first call
...
The current wording implies that the first time a jitted function
is called, the computation happens in Python. It's actually only
the tracing that happens in Python, and the compiled code is run
during the first call. The distinction is important e.g., to
understand why it might make sense to jit a function that's only
called once.
PiperOrigin-RevId: 526906176
2023-04-25 02:41:12 -07:00
Jake VanderPlas
fbe4f10403
Change to simpler import for jax.config
2023-04-21 11:51:22 -07:00
jax authors
975e76ef76
Merge pull request #15664 from skye:tpu_install
...
PiperOrigin-RevId: 525605301
2023-04-19 18:18:32 -07:00
jax authors
1de4d14da8
Merge pull request #15656 from laqua-stack:add-special-gamma-fcn
...
PiperOrigin-RevId: 525566749
2023-04-19 15:28:36 -07:00
Jake VanderPlas
a083ba7853
DOC: explicitly mention io_callback in FAQ
2023-04-19 12:30:53 -07:00
Skye Wanderman-Milne
b917a31f56
Update TPU install on main docs page
2023-04-19 17:52:16 +00:00
laqua-stack
d742733bea
feat (scipy.special): Add a xla version of scipy.special.gamma function
...
- Add gamma fcn api in scipy.special
- Add tests for this purpose
- Add function to the docs
Currently, there is no implementation of the gamma function in jax
but there is one in scipy.special. This breaks some higher level
jit-compilation like in the blackjax backend for pymc. This commit
adds the missing gamma function.
Resolves : #15409
2023-04-18 21:10:22 +02:00