18992 Commits

Author SHA1 Message Date
Jake VanderPlas
6569b320b2 CI: bump mypy to version 1.8.0 2024-01-10 10:20:55 -08:00
jax authors
6f5acf0c0d Merge pull request #19289 from jakevdp:update-array-api-tests
PiperOrigin-RevId: 597281920
2024-01-10 10:04:29 -08:00
Jake VanderPlas
635b9b6029 CI: update array API tests pin 2024-01-10 09:27:57 -08:00
jax authors
f31d1e7599 Bump NCCL version on JAX OSS
PiperOrigin-RevId: 597257455
2024-01-10 08:28:15 -08:00
jax authors
adf05d520a Merge pull request #19282 from gnecula:poly_is_symb
PiperOrigin-RevId: 597170362
2024-01-10 01:04:21 -08:00
George Necula
df280a11b0 [shape_poly] Introduce is_symbolic_dim and deprecate is_poly_dim.
The old is_poly_dim seems to be used in a few places externally.
This was from the time when the symbolic dimensions were polynomials,
now we use the more generic term symbolic dimension or expression.

We introduce is_symbolic_dim and we export it through the jax.experimental.export.
We plan to make the entire shape_poly.py module private, and this is
a necessary step.
2024-01-10 10:10:30 +02:00
jax authors
88169cf9e5 Merge pull request #19275 from j2kun:main
PiperOrigin-RevId: 597148779
2024-01-09 23:06:02 -08:00
jax authors
cbb40c101d Update XLA dependency to use revision
dcb94c46b2.

PiperOrigin-RevId: 597109537
2024-01-09 19:34:41 -08:00
Nitin Srinivasan
b58772cdb4 Enable Bazel remote cache in macOS continuous builds
PiperOrigin-RevId: 597100776
2024-01-09 18:30:31 -08:00
Jeremy Kun
2e6e5da49b docs/pallas: remove list from out_specs 2024-01-09 15:47:18 -08:00
jax authors
f74faf9c9c Merge pull request #19279 from jakevdp:fix-jupytext
PiperOrigin-RevId: 597063608
2024-01-09 15:36:08 -08:00
Jake VanderPlas
93500a8477 lint: fix jupytext version 2024-01-09 15:28:17 -08:00
jax authors
77cb0f141c Merge pull request #19147 from 8bitmp3:jax-docs-jaxprs
PiperOrigin-RevId: 597057971
2024-01-09 15:21:43 -08:00
jax authors
38f23b49af Merge pull request #19277 from jakevdp:jupytext-version
PiperOrigin-RevId: 597057949
2024-01-09 15:13:40 -08:00
Jeremy Kun
4ecbed1322 docs/pallas: sync jupyter notebook 2024-01-09 15:00:42 -08:00
Jake VanderPlas
10eae3f93a CI: update jupytext version 2024-01-09 14:34:21 -08:00
Jeremy Kun
87f914e02f docs/pallas: fix mismatched pytree specs
Running the example as-is gives

```
ValueError: Pytree specs for `out_shape` and `out_specs` must match: PyTreeDef(*) vs. PyTreeDef((*,))
```

Giving a list argument to `out_shape` seems to fix the issue.
2024-01-09 13:56:28 -08:00
8bitmp3
66a845e0c6 Upgrade JAX internals 301 jaxpr language tutorial 2024-01-09 21:53:22 +00:00
jax authors
04f2c91399 Merge pull request #19092 from 8bitmp3:jax-docs-gradient-checkpointing
PiperOrigin-RevId: 597034325
2024-01-09 13:50:13 -08:00
Sergei Lebedev
ba10775eda Added a compatibility overlay for Triton Python APIs
Follow up changes will gradually re-implement these APIs using the MLIR
builders added in google/jax#19159.

PiperOrigin-RevId: 597023799
2024-01-09 13:13:56 -08:00
jax authors
226d6ae392 Merge pull request #19272 from jakevdp:pallas-indexing
PiperOrigin-RevId: 597020538
2024-01-09 13:02:57 -08:00
Tomás Longeri
92bcd3f902 [Mosaic] apply_vector_layout: Copy docstring for VectorLayout in Python to C++
This is in preparation for removing the Python version.

PiperOrigin-RevId: 597015430
2024-01-09 12:43:59 -08:00
Jake VanderPlas
be8183d746 pallas: improve indexing trace time 2024-01-09 11:32:00 -08:00
jax authors
e76f514b49 Merge pull request #19270 from jakevdp:permute-dims-sig
PiperOrigin-RevId: 596992011
2024-01-09 11:20:20 -08:00
jax authors
2356e57af4 [Mosaic] Add lowering rule for jax.lax.logical_not
jax.lax.logical_not(x) is lowered into: xor x, 0xffffffff

PiperOrigin-RevId: 596965921
2024-01-09 09:58:43 -08:00
Jake VanderPlas
707657e5b7 Adjust permute_dims signature to match NumPy
This really doesn't matter because it's a position-only argument, but this
change satisfies our tests and is easier than making the tests smarter.
2024-01-09 09:56:19 -08:00
Sergei Lebedev
f219482212 The Triton MLIR bindings now include auto-generated wrappers for enums
PiperOrigin-RevId: 596873541
2024-01-09 03:00:47 -08:00
jax authors
df0f1e06e0 Merge pull request #19258 from gnecula:poly_new_order
PiperOrigin-RevId: 596818710
2024-01-08 23:00:43 -08:00
George Necula
b7f82e8cad [shape_poly] Improve the lexicographic ordering of symbolic expressions
In preparation for upcoming changes in the reasoning about
inequalities, we change the lexicographic ordering to
ensure that a symbolic expressions is strictly larger than
any constituent subexpressions. We add a `_size` attribute
that computes (and caches) the syntactic size of the expression.
2024-01-09 08:50:54 +02:00
jax authors
7ad7890068 Merge pull request #18953 from gnecula:poly_min_max
PiperOrigin-RevId: 596768330
2024-01-08 19:37:00 -08:00
jax authors
27f7035105 Update XLA dependency to use revision
5d539b8bd3.

PiperOrigin-RevId: 596767921
2024-01-08 19:28:48 -08:00
jax authors
6a99e38a82 Merge pull request #19252 from jakevdp:fix-vecdot
PiperOrigin-RevId: 596762791
2024-01-08 18:54:07 -08:00
jax authors
9556b09349 Merge pull request #19253 from jakevdp:fix-bitwise-count
PiperOrigin-RevId: 596756207
2024-01-08 18:20:13 -08:00
jax authors
06152e645a Merge pull request #19248 from jakevdp:atanh-test
PiperOrigin-RevId: 596755300
2024-01-08 18:11:20 -08:00
Jake VanderPlas
0ea2fc6c15 NumPy 2.0: skip complex atanh test 2024-01-08 18:04:36 -08:00
Jake VanderPlas
f901bead48 jnp.linalg.vecdot: fix broadcasting & conjugation semantics 2024-01-08 18:03:44 -08:00
jax authors
8f850f243c Merge pull request #19247 from jakevdp:new-ufuncs
PiperOrigin-RevId: 596742534
2024-01-08 17:11:31 -08:00
jax authors
856915f3c4 Merge pull request #19244 from jakevdp:permute-dims
PiperOrigin-RevId: 596741266
2024-01-08 17:02:46 -08:00
Mark Sandler
773e1499f1 Updates doc for host_local_array_to_global_array to reflect a few use case patterns and adds a few extra tests.
PiperOrigin-RevId: 596739035
2024-01-08 16:51:26 -08:00
8bitmp3
efeb20e380 Upgrade JAX Gradient Checkpointing doc 2024-01-08 23:36:21 +00:00
Jake VanderPlas
60665f90ec jnp.bitwise_count: promote input to numeric 2024-01-08 15:19:54 -08:00
jax authors
da96633f11 Correct the cache miss metric instrumentation due to the new min cache entry size flag
Since introduction of the min cache entry size check for compilation cache, the cache miss metric overcounts the skipped caches whose sizes are smaller than the min cache entry size. After moving the metric instrumentation to compilation_cache.put_executable_and_time, the cache miss metric will be incremented if both compile time and cache entry size are greater than the minimum thresholds.

PiperOrigin-RevId: 596696013
2024-01-08 14:03:33 -08:00
Yash Katariya
dccc0e8e5c Preserve the specs passed by the user in the output sharding from a eager shard_map.
PiperOrigin-RevId: 596665787
2024-01-08 12:09:20 -08:00
Jake VanderPlas
bd4b01ebbd List several NumPy 2.0 ufuncs as unimplemented 2024-01-08 11:12:58 -08:00
George Necula
6b7b3a3902 [shape_poly] Replace non_negative_dim with max_dim and min_dim.
Previously, we had `core.non_negative_dim` and we used it to
express `max(d, 0)`. This is needed in several places internally
to express index computations involving clamping (for numpy
indexing), or striding and dilation (which have a conditional
semantics). It seemed that this special case was sufficient,
and we expressed `max(a, b)` as `a + non_negative(b - a)` and
`min(a, b)` as `a - non_negative(a - b)`.

One drawback was that `non_negative` can be a surprising
construct when it appears in error messages. Also, users need
`max` and `min` computations with dimensions. It is clearer if
we use `max` and `min` directly instead of rewriting these to
use `non_negative`. The drawback is that we now have to duplicate
some internal logic to for `max` and `min`, but overall I feel
this is worth it for the better error messages we get.
2024-01-08 20:54:18 +02:00
jax authors
f33f0e4337 Merge pull request #19220 from jakevdp:dep-kind
PiperOrigin-RevId: 596636214
2024-01-08 10:32:34 -08:00
jax authors
25cf5cfee1 Merge pull request #19245 from jakevdp:complex-warning
PiperOrigin-RevId: 596633005
2024-01-08 10:23:32 -08:00
Jake VanderPlas
6ab9398574 Remove reference to np.ComplexWarning
This is not available in NumPy 2.0.
2024-01-08 09:53:06 -08:00
Jake VanderPlas
d673b9bf5c [array api] add jax.numpy.permute_dims function 2024-01-08 09:30:51 -08:00
jax authors
754b74d542 Allow cumulative sums to be int32 in jax2tf
PiperOrigin-RevId: 596604690
2024-01-08 08:40:35 -08:00