20573 Commits

Author SHA1 Message Date
jax authors
500da57e91 Merge pull request #21077 from merrymercy:patch-1
PiperOrigin-RevId: 631409738
jax-v0.4.27 jaxlib-v0.4.27 jax-v0.4.27-rc
2024-05-07 07:07:04 -07:00
Yash Katariya
70b4477296 Start jax and jaxlib 0.4.27 release
PiperOrigin-RevId: 631409685
2024-05-07 07:01:24 -07:00
Adam Paszke
326adc01a5 [Mosaic GPU] Adjust memref.expand_shape construction to pass in the new args
PiperOrigin-RevId: 631404097
2024-05-07 06:36:36 -07:00
jax authors
3e5a18f1e2 Update XLA dependency to use revision
873d09720f.

PiperOrigin-RevId: 631274530
2024-05-06 20:37:43 -07:00
jax authors
cb0c49850c Merge pull request #21081 from hawkinsp:sourcemap
PiperOrigin-RevId: 631236806
2024-05-06 17:33:12 -07:00
jax authors
4de346485d Fix that the insufficient output HBM buffer init would cause the <unk> token generated for quantized int8 model.
PiperOrigin-RevId: 631235764
2024-05-06 17:28:13 -07:00
jax authors
eee2783e85 Merge pull request #21070 from shuhand0:rel0.0.7
PiperOrigin-RevId: 631218770
2024-05-06 16:22:15 -07:00
jax authors
f6d88525a8 Merge pull request #20327 from selamw1:add_examples
PiperOrigin-RevId: 631186425
2024-05-06 14:30:06 -07:00
Shuhan Ding
aac36799fd
fix jaxlib config name 2024-05-06 13:51:22 -07:00
Selam Waktola
9caf59d68b improve documentation for ix_ 2024-05-06 13:43:55 -07:00
jax authors
3d3cb0bd2c Merge pull request #20842 from Micky774:array-api-default-promotion
PiperOrigin-RevId: 631168892
2024-05-06 13:39:56 -07:00
jax authors
1013b1ad85 Merge pull request #21079 from jakevdp:tensorinv
PiperOrigin-RevId: 631168828
2024-05-06 13:35:03 -07:00
jax authors
bb6aa12fc1 Merge pull request #21087 from jakevdp:upstream-print-version
PiperOrigin-RevId: 631154080
2024-05-06 12:48:29 -07:00
Peter Hawkins
d014f5dc5f Compute source maps when pretty-printing jaxprs.
This change is in preparation for adding support for emitting https://tc39.es/source-map/ source map information for jaxprs, so the relationship between a jaxpr and its Python code can be visualized using tooling for that purpose.

This change adds a new `source_map()` pretty printer document, which causes the pretty-printer to populate a source_map side output during pretty printing.
The change also teaches the core jaxpr pretty printer to populate source map information on each equation.
2024-05-06 15:45:25 -04:00
Jake VanderPlas
7d96e78a55 CI: print numpy/scipy version in upstream job 2024-05-06 11:14:38 -07:00
jax authors
a265e42192 Add an experimental, Clang version of the Windows CI job.
Once proven to work, this job will be deleted, and the MSVC job changed to use Clang.

PiperOrigin-RevId: 631122967
2024-05-06 11:10:09 -07:00
Jake VanderPlas
4a363156b9 jnp.linalg tensorinv & tensorsolve: improve implementation & docs 2024-05-06 11:08:36 -07:00
jax authors
d26bd735f6 Merge pull request #21084 from jakevdp:fix-upstream
PiperOrigin-RevId: 631112318
2024-05-06 10:49:34 -07:00
jax authors
7e9ef1e4d2 Merge pull request #21078 from jakevdp:numpy-linalg-doc
PiperOrigin-RevId: 631112228
2024-05-06 10:44:42 -07:00
jax authors
fb65ba4adf Add a config for using Clang on Windows.
PiperOrigin-RevId: 631112031
2024-05-06 10:39:28 -07:00
jax authors
8ba5c64794 Pass bazel_options directly to the Bazel command, instead of into .bazelrc.
PiperOrigin-RevId: 631099970
2024-05-06 10:05:19 -07:00
Jake VanderPlas
40b2d4852e jnp.linalg: improve API documentation 2024-05-06 09:22:59 -07:00
Jake VanderPlas
6f7ebff585 random_lax_test: fix kstest for newer NumPy 2024-05-06 09:20:07 -07:00
Meekail Zain
34c5163fd2 Refactored common upcast for integral-type accumulators 2024-05-06 15:13:10 +00:00
Lianmin Zheng
0eed28a010
Fix a typo in jax.jit docstring 2024-05-06 04:59:23 -07:00
jax authors
7681493760 Don't create temp directory when module is getting imported.
PiperOrigin-RevId: 630958402
2024-05-06 00:58:45 -07:00
jax authors
047ea210e8 Update XLA dependency to use revision
8833ecd870.

PiperOrigin-RevId: 630731298
2024-05-04 19:09:55 -07:00
jax authors
a1c82219e2 Update XLA dependency to use revision
25d66ce58c.

PiperOrigin-RevId: 630557277
2024-05-03 19:19:41 -07:00
jax authors
1b804a7720 Merge pull request #21056 from mattjj:vmap-grad-remat-shmap-bug
PiperOrigin-RevId: 630555588
2024-05-03 19:06:46 -07:00
Matthew Johnson
7a87010f84 [shard_map] better fix for spmd_axis_name issues with shmap residuals
The fix in #21032 was not correct because it assumed that the set of all mesh
axis names appearing in in_specs was an upper bound on the set of mesh axes
over which residuals could be device-varying. But collectives can introduce
device variance! So it's not an upper bound.

We track device variance when check_rep=True, but often people set
check_rep=False (e.g. when using pallas_call in a shard_map). So relying on our
device variance tracking would be limiting. That may be a decent long term
solution, if we can make it easy to annotate pallas_calls with device variance
information. But it's not a great short term one to unblock things.

So instead I temporrarily went with context sensitivity: instead of making
residuals sharded over all mesh.axis_names (as we did before these patches), we
make them sharded over all mesh axis names _excluding_ any spmd_axis_names in
our dynamic context (by looking at the traces in our trace stack). It's illegal
to mention any spmd_axis_names in collectives (indeed anywhere in the body of
the function being vmapped), but I don't think we check it.

TODO(mattjj): add more testing (maybe in follow-ups)
2024-05-04 01:31:15 +00:00
Jake VanderPlas
e95173a4d3 Require arraylike input for several jax.numpy functions
PiperOrigin-RevId: 630532821
2024-05-03 16:55:10 -07:00
jax authors
53208ffe27 Merge pull request #21058 from jakevdp:jnp-delete-doc
PiperOrigin-RevId: 630510340
2024-05-03 15:14:12 -07:00
Jake VanderPlas
88318e60d2 jnp.delete: better docs 2024-05-03 14:41:06 -07:00
jax authors
fc6a5d3a8d Merge pull request #21059 from jakevdp:numpy-doc-tests
PiperOrigin-RevId: 630500410
2024-05-03 14:35:03 -07:00
jax authors
20722248f8 Merge pull request #21064 from jakevdp:doc-new-tutorials
PiperOrigin-RevId: 630484720
2024-05-03 13:39:02 -07:00
Jake VanderPlas
10ed827fe9 DOC: replace old tutorials with new content 2024-05-03 12:20:06 -07:00
Jake VanderPlas
f2c2892c79 Refactor jax.numpy docstring tests 2024-05-03 11:04:43 -07:00
jax authors
7e20e53032 Merge pull request #21057 from jakevdp:scipy-imports
PiperOrigin-RevId: 630435054
2024-05-03 10:49:57 -07:00
Jake VanderPlas
ff67e51e7e Remove last scipy imports 2024-05-03 10:20:05 -07:00
jax authors
e70191bd9e Merge pull request #21055 from shoyer:relu-doc
PiperOrigin-RevId: 630425740
2024-05-03 10:18:49 -07:00
Stephan Hoyer
c77370c492 Recursively pull out __wrapped__ in linkcode_resolve()
This should actually fix the source lookup for `jax.nn.relu`, which
uses both `custom_jvp` and `jit` decorators.
2024-05-03 09:10:00 -07:00
jax authors
c0cfc7ae9f Merge pull request #21032 from mattjj:vmap-grad-shmap-bug
PiperOrigin-RevId: 630375950
2024-05-03 06:47:22 -07:00
jax authors
989ea61697 Merge pull request #21047 from shoyer:linkcode-robust
PiperOrigin-RevId: 630367113
2024-05-03 06:00:40 -07:00
jax authors
d05d29d889 Merge pull request #21050 from rajasekharporeddy:test_branch3
PiperOrigin-RevId: 630339307
2024-05-03 03:24:22 -07:00
rajasekharporeddy
ccabdb29ea Fix typos in docs and an error message 2024-05-03 12:08:22 +05:30
jax authors
52bf27d85c Update XLA dependency to use revision
9fee947f11.

PiperOrigin-RevId: 630247294
2024-05-02 19:26:52 -07:00
Shuhan Ding
28fe45d872
metal_plugin ci with jaxlib nightly 2024-05-02 18:40:01 -07:00
Shuhan Ding
99e5b8e999
enable conv1d test 2024-05-02 18:20:39 -07:00
Stephan Hoyer
f6ee61273a Make linkcode_resolve() a bit more robust
This handles three cases that came up when adopting this snippet (for
finding source code corresponding to API docs) [for neuralgcm](https://github.com/google-research/neuralgcm/pull/58):

1. documenting class method or attributes
2. documenting properties
3. documenting `jit` decorated methods

I'm not sure if case (1) or (2) comes up in the JAX docs, but case (3)
definitely does -- `jit` decorated functions like `jax.nn.relu`
[currently do not](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.relu.html)
have source code link.
2024-05-02 17:57:02 -07:00
Shuhan Ding
0c517f2d83
update along with lax_numpy_test 2024-05-02 17:18:55 -07:00