1530 Commits

Author SHA1 Message Date
8bitmp3
7edc80d635 Update thinking-in-jax working-with-pytrees 2023-12-13 22:57:46 +00:00
8bitmp3
a4003bf0ae Upgrade How to think in JAX 2023-12-13 21:39:17 +00:00
jax authors
0281596ec9 Merge pull request #18869 from 8bitmp3:small-fix-jax-pytrees
PiperOrigin-RevId: 590680938
2023-12-13 12:40:06 -08:00
jax authors
d269c581f0 Merge pull request #18582 from 8bitmp3:jax-docs-debugging-101
PiperOrigin-RevId: 590680437
2023-12-13 12:31:58 -08:00
8bitmp3
7bb4b0fab4 Update working-with-pytrees.md 2023-12-13 00:40:42 +00:00
8bitmp3
14407b9a06 Upgrade JAX debugging doc 2023-12-13 00:19:41 +00:00
Jake VanderPlas
fe2ad89209 array api: add jnp.linalg.cross & jnp.linalg.outer 2023-12-12 11:22:31 -08:00
George Necula
b077483bfa [export] Add support for serialization and deserialization of Exported
At the moment we can export a JAX function into an Exported and we can invoke an Exported from another JAX function, but there is no way to serialize an Exported to be able to use it in another process.

Exported now has all the features we had in mind, so it is a reasonable time to add a serialization method. The intention is for the serialization to have backwards compatibility guarantees, meaning that we can deserialize an Exported that has been serialized with an older version of JAX. This PR does not add explicit support for versioning, nor backwards compatibility tests. Those will follow.

Here we add serialization and deserialization to bytearray, using the flatbuffers package. We use flatbuffers because it is simple, it has backwards and forwards compatibility guarantees, it is a lightweight dependency that does not require additional build steps, and it is fast (deserialization simply indexes into the bytearray rather than creating a Python structure).

In the process of implementing this we have done some small cleanup of the Exported structure:

  * renamed serialization_version to mlir_module_serialization_version
  * renamed disabled_checks to disabled_safety_checks

This code is tested by changing export_test.py to interpose a serialization followed by a deserialization every time we export.export.

There is a known bug with the serialization of effects, so I disabled one of the export tests. Will fix in a subsequent PR.

PiperOrigin-RevId: 590078785
2023-12-11 23:23:02 -08:00
Jake VanderPlas
a52d18781e Add experimental static key reuse checking 2023-12-11 12:03:48 -08:00
jax authors
c2f8e18016 Merge pull request #18862 from superbobry:pp-improvement
PiperOrigin-RevId: 588906413
2023-12-07 14:15:33 -08:00
jax authors
397a44e504 Merge pull request #18864 from 8bitmp3:jax-docs-pytrees-101
PiperOrigin-RevId: 588859487
2023-12-07 11:40:19 -08:00
8bitmp3
5eebb91e1e Upgrade JAX pytrees doc 2023-12-07 19:18:04 +00:00
jax authors
c4239fc05e Merge pull request #18797 from jakevdp:factorial
PiperOrigin-RevId: 588809589
2023-12-07 08:59:57 -08:00
Sergei Lebedev
ea158d3109 Print pjit name= before other params
The jaxpr sometimes gets pretty big, making it hard to see the name.
2023-12-07 16:54:07 +00:00
Neil Girdhar
9f85beb56b Expose PrecisionLike
This is used in client code like:
https://github.com/search?q=repo%3Agoogle%2Fflax%20%20PrecisionLike&type=code
2023-12-06 14:41:22 -05:00
Peter Hawkins
78543f7bb8 Add jax.extend.mlir.
Some users of JAX want to use the MLIR dialects defined in jaxlib. In particular, these need to be used by custom lowering rules. Add a semi-public (jax.extend) API to access these, rather than having them use jax._src.lib.mlir.

PiperOrigin-RevId: 588448489
2023-12-06 09:16:43 -08:00
jax authors
dc469c3851 Merge pull request #18676 from 8bitmp3:jax-docs-autodiff-101-201
PiperOrigin-RevId: 588172128
2023-12-05 13:29:10 -08:00
8bitmp3
46bad2bd62 Upgrade JAX Autodiff 101 2023-12-05 00:00:13 +00:00
8bitmp3
b1a8bc6a83 Upgrade JAX Installation doc 2023-12-04 21:42:13 +00:00
Jake VanderPlas
70d0f60ce1 Add special.factorial function 2023-12-04 06:13:14 -08:00
Jake VanderPlas
d77cd9a0f4 Add jax.numpy.astype function 2023-11-30 15:50:22 -08:00
Jake VanderPlas
97beb01c43 Deprecate the device() method of JAX arrays 2023-11-30 11:43:02 -08:00
Jake VanderPlas
0aec40a16f Deprecate arr.device_buffer and arr.device_buffers 2023-11-29 15:31:01 -08:00
jax authors
8020e7d535 Merge pull request #18614 from mattjj:custom-gradient-in-readthedocs
PiperOrigin-RevId: 586074376
2023-11-28 12:47:03 -08:00
jax authors
c855bb0371 Merge pull request #18660 from mmarcinmichal:detached
PiperOrigin-RevId: 586032984
2023-11-28 10:43:14 -08:00
Jake VanderPlas
83cb3369d2 JAX tutorials: quickstart 2023-11-28 08:28:18 -08:00
Marcin Mirończuk
b9f70211b2 Squashed commit - documentation matplotlib
Add information to documentation about that some test targets required matplotlib - squashed commit

Add information to documentation about that some test targets, like a `//tests:logpcg_tests` optionally use matplotlib, so the user may need to `pip install matplotlib` to run tests via bazel (https://github.com/google/jax/pull/18660)

Squashed commit - documentation matplotlib
2023-11-28 11:36:50 +01:00
jax authors
025aa0dcfb Correcting typo in docs, 06-parallelism.
PiperOrigin-RevId: 585889672
2023-11-28 01:34:49 -08:00
jax authors
961ba3cd42 Merge pull request #18631 from jakevdp:tutorials-random
PiperOrigin-RevId: 584612036
2023-11-22 06:42:14 -08:00
Jake VanderPlas
52adb1c6a5 JAX tutorials: add intro to jit compilation 2023-11-21 14:27:44 -08:00
Jake VanderPlas
cacadf43c0 JAX tutorials: pseudorandom numbers 2023-11-21 14:26:06 -08:00
Jake VanderPlas
29a2e8a362 JAX tutorials: add automatic vectorization 2023-11-21 10:39:02 -08:00
jax authors
29eec05c92 Merge pull request #18615 from hawkinsp:docsfix
PiperOrigin-RevId: 584168352
2023-11-20 17:24:38 -08:00
Jake VanderPlas
66a3eff28f JAX tutorials: external callbacks 2023-11-20 17:04:05 -08:00
Peter Hawkins
be9cafc163 Apply jupytext fix to #18220. 2023-11-20 16:42:59 -05:00
ArthurConmy
8077db5ceb Implement bug fix (squash commit) 2023-11-20 16:37:46 -05:00
Matthew Johnson
808958b69e add jax.custom_gradient to readthedocs 2023-11-20 13:26:52 -08:00
Jake VanderPlas
c0b07ea48c DOC: update package requirements 2023-11-20 11:22:29 -08:00
jax authors
ab9c973031 Merge pull request #18600 from nouiz:doc_compilation_cache
PiperOrigin-RevId: 584068904
2023-11-20 10:43:32 -08:00
Frederic Bastien
72b6c9cf0b Document the compilation cache 2023-11-20 07:03:30 -08:00
jiayaobo
ae2387dc27 add random.binomial
update

update

modify
2023-11-19 14:51:10 +08:00
Jake VanderPlas
7456921055 Move docs/tutorial to docs/tutorials 2023-11-17 10:24:56 -08:00
Jake VanderPlas
271d31c1c8 Add jax.experimental.array_api interface 2023-11-16 14:21:04 -08:00
jax authors
7657a0fb15 Merge pull request #18539 from NeilGirdhar:ruff
PiperOrigin-RevId: 583105786
2023-11-16 11:15:19 -08:00
Jake VanderPlas
f29ec904f6 CI: fix doc build 2023-11-16 07:59:07 -08:00
Neil Girdhar
3c920c0120 Switch from flake8 to Ruff 2023-11-15 22:35:52 -05:00
jax authors
840b5c5d6d Merge pull request #18499 from renecotyfanboy:hyp1f1_poch
PiperOrigin-RevId: 582765493
2023-11-15 12:25:59 -08:00
jax authors
f2c89a43dc Merge pull request #18527 from carlosgmartin:squareplus
PiperOrigin-RevId: 582735733
2023-11-15 11:14:13 -08:00
jax authors
fd155b4fd7 Merge pull request #17850 from nouiz:regression_doc
PiperOrigin-RevId: 582735679
2023-11-15 11:06:09 -08:00
sdupourque
47ca51f474 implementation of poch and hyp1f1 2023-11-15 20:01:00 +01:00