Jake VanderPlas
6569b320b2
CI: bump mypy to version 1.8.0
2024-01-10 10:20:55 -08:00
jax authors
6f5acf0c0d
Merge pull request #19289 from jakevdp:update-array-api-tests
...
PiperOrigin-RevId: 597281920
2024-01-10 10:04:29 -08:00
Jake VanderPlas
635b9b6029
CI: update array API tests pin
2024-01-10 09:27:57 -08:00
jax authors
f31d1e7599
Bump NCCL version on JAX OSS
...
PiperOrigin-RevId: 597257455
2024-01-10 08:28:15 -08:00
jax authors
adf05d520a
Merge pull request #19282 from gnecula:poly_is_symb
...
PiperOrigin-RevId: 597170362
2024-01-10 01:04:21 -08:00
George Necula
df280a11b0
[shape_poly] Introduce is_symbolic_dim and deprecate is_poly_dim.
...
The old is_poly_dim seems to be used in a few places externally.
This was from the time when the symbolic dimensions were polynomials,
now we use the more generic term symbolic dimension or expression.
We introduce is_symbolic_dim and we export it through the jax.experimental.export.
We plan to make the entire shape_poly.py module private, and this is
a necessary step.
2024-01-10 10:10:30 +02:00
jax authors
88169cf9e5
Merge pull request #19275 from j2kun:main
...
PiperOrigin-RevId: 597148779
2024-01-09 23:06:02 -08:00
jax authors
cbb40c101d
Update XLA dependency to use revision
...
dcb94c46b2
.
PiperOrigin-RevId: 597109537
2024-01-09 19:34:41 -08:00
Nitin Srinivasan
b58772cdb4
Enable Bazel remote cache in macOS continuous builds
...
PiperOrigin-RevId: 597100776
2024-01-09 18:30:31 -08:00
Jeremy Kun
2e6e5da49b
docs/pallas: remove list from out_specs
2024-01-09 15:47:18 -08:00
jax authors
f74faf9c9c
Merge pull request #19279 from jakevdp:fix-jupytext
...
PiperOrigin-RevId: 597063608
2024-01-09 15:36:08 -08:00
Jake VanderPlas
93500a8477
lint: fix jupytext version
2024-01-09 15:28:17 -08:00
jax authors
77cb0f141c
Merge pull request #19147 from 8bitmp3:jax-docs-jaxprs
...
PiperOrigin-RevId: 597057971
2024-01-09 15:21:43 -08:00
jax authors
38f23b49af
Merge pull request #19277 from jakevdp:jupytext-version
...
PiperOrigin-RevId: 597057949
2024-01-09 15:13:40 -08:00
Jeremy Kun
4ecbed1322
docs/pallas: sync jupyter notebook
2024-01-09 15:00:42 -08:00
Jake VanderPlas
10eae3f93a
CI: update jupytext version
2024-01-09 14:34:21 -08:00
Jeremy Kun
87f914e02f
docs/pallas: fix mismatched pytree specs
...
Running the example as-is gives
```
ValueError: Pytree specs for `out_shape` and `out_specs` must match: PyTreeDef(*) vs. PyTreeDef((*,))
```
Giving a list argument to `out_shape` seems to fix the issue.
2024-01-09 13:56:28 -08:00
8bitmp3
66a845e0c6
Upgrade JAX internals 301 jaxpr language tutorial
2024-01-09 21:53:22 +00:00
jax authors
04f2c91399
Merge pull request #19092 from 8bitmp3:jax-docs-gradient-checkpointing
...
PiperOrigin-RevId: 597034325
2024-01-09 13:50:13 -08:00
Sergei Lebedev
ba10775eda
Added a compatibility overlay for Triton Python APIs
...
Follow up changes will gradually re-implement these APIs using the MLIR
builders added in google/jax#19159 .
PiperOrigin-RevId: 597023799
2024-01-09 13:13:56 -08:00
jax authors
226d6ae392
Merge pull request #19272 from jakevdp:pallas-indexing
...
PiperOrigin-RevId: 597020538
2024-01-09 13:02:57 -08:00
Tomás Longeri
92bcd3f902
[Mosaic] apply_vector_layout: Copy docstring for VectorLayout in Python to C++
...
This is in preparation for removing the Python version.
PiperOrigin-RevId: 597015430
2024-01-09 12:43:59 -08:00
Jake VanderPlas
be8183d746
pallas: improve indexing trace time
2024-01-09 11:32:00 -08:00
jax authors
e76f514b49
Merge pull request #19270 from jakevdp:permute-dims-sig
...
PiperOrigin-RevId: 596992011
2024-01-09 11:20:20 -08:00
jax authors
2356e57af4
[Mosaic] Add lowering rule for jax.lax.logical_not
...
jax.lax.logical_not(x) is lowered into: xor x, 0xffffffff
PiperOrigin-RevId: 596965921
2024-01-09 09:58:43 -08:00
Jake VanderPlas
707657e5b7
Adjust permute_dims signature to match NumPy
...
This really doesn't matter because it's a position-only argument, but this
change satisfies our tests and is easier than making the tests smarter.
2024-01-09 09:56:19 -08:00
Sergei Lebedev
f219482212
The Triton MLIR bindings now include auto-generated wrappers for enums
...
PiperOrigin-RevId: 596873541
2024-01-09 03:00:47 -08:00
jax authors
df0f1e06e0
Merge pull request #19258 from gnecula:poly_new_order
...
PiperOrigin-RevId: 596818710
2024-01-08 23:00:43 -08:00
George Necula
b7f82e8cad
[shape_poly] Improve the lexicographic ordering of symbolic expressions
...
In preparation for upcoming changes in the reasoning about
inequalities, we change the lexicographic ordering to
ensure that a symbolic expressions is strictly larger than
any constituent subexpressions. We add a `_size` attribute
that computes (and caches) the syntactic size of the expression.
2024-01-09 08:50:54 +02:00
jax authors
7ad7890068
Merge pull request #18953 from gnecula:poly_min_max
...
PiperOrigin-RevId: 596768330
2024-01-08 19:37:00 -08:00
jax authors
27f7035105
Update XLA dependency to use revision
...
5d539b8bd3
.
PiperOrigin-RevId: 596767921
2024-01-08 19:28:48 -08:00
jax authors
6a99e38a82
Merge pull request #19252 from jakevdp:fix-vecdot
...
PiperOrigin-RevId: 596762791
2024-01-08 18:54:07 -08:00
jax authors
9556b09349
Merge pull request #19253 from jakevdp:fix-bitwise-count
...
PiperOrigin-RevId: 596756207
2024-01-08 18:20:13 -08:00
jax authors
06152e645a
Merge pull request #19248 from jakevdp:atanh-test
...
PiperOrigin-RevId: 596755300
2024-01-08 18:11:20 -08:00
Jake VanderPlas
0ea2fc6c15
NumPy 2.0: skip complex atanh test
2024-01-08 18:04:36 -08:00
Jake VanderPlas
f901bead48
jnp.linalg.vecdot: fix broadcasting & conjugation semantics
2024-01-08 18:03:44 -08:00
jax authors
8f850f243c
Merge pull request #19247 from jakevdp:new-ufuncs
...
PiperOrigin-RevId: 596742534
2024-01-08 17:11:31 -08:00
jax authors
856915f3c4
Merge pull request #19244 from jakevdp:permute-dims
...
PiperOrigin-RevId: 596741266
2024-01-08 17:02:46 -08:00
Mark Sandler
773e1499f1
Updates doc for host_local_array_to_global_array to reflect a few use case patterns and adds a few extra tests.
...
PiperOrigin-RevId: 596739035
2024-01-08 16:51:26 -08:00
8bitmp3
efeb20e380
Upgrade JAX Gradient Checkpointing doc
2024-01-08 23:36:21 +00:00
Jake VanderPlas
60665f90ec
jnp.bitwise_count: promote input to numeric
2024-01-08 15:19:54 -08:00
jax authors
da96633f11
Correct the cache miss metric instrumentation due to the new min cache entry size flag
...
Since introduction of the min cache entry size check for compilation cache, the cache miss metric overcounts the skipped caches whose sizes are smaller than the min cache entry size. After moving the metric instrumentation to compilation_cache.put_executable_and_time, the cache miss metric will be incremented if both compile time and cache entry size are greater than the minimum thresholds.
PiperOrigin-RevId: 596696013
2024-01-08 14:03:33 -08:00
Yash Katariya
dccc0e8e5c
Preserve the spec
s passed by the user in the output sharding from a eager shard_map.
...
PiperOrigin-RevId: 596665787
2024-01-08 12:09:20 -08:00
Jake VanderPlas
bd4b01ebbd
List several NumPy 2.0 ufuncs as unimplemented
2024-01-08 11:12:58 -08:00
George Necula
6b7b3a3902
[shape_poly] Replace non_negative_dim with max_dim and min_dim.
...
Previously, we had `core.non_negative_dim` and we used it to
express `max(d, 0)`. This is needed in several places internally
to express index computations involving clamping (for numpy
indexing), or striding and dilation (which have a conditional
semantics). It seemed that this special case was sufficient,
and we expressed `max(a, b)` as `a + non_negative(b - a)` and
`min(a, b)` as `a - non_negative(a - b)`.
One drawback was that `non_negative` can be a surprising
construct when it appears in error messages. Also, users need
`max` and `min` computations with dimensions. It is clearer if
we use `max` and `min` directly instead of rewriting these to
use `non_negative`. The drawback is that we now have to duplicate
some internal logic to for `max` and `min`, but overall I feel
this is worth it for the better error messages we get.
2024-01-08 20:54:18 +02:00
jax authors
f33f0e4337
Merge pull request #19220 from jakevdp:dep-kind
...
PiperOrigin-RevId: 596636214
2024-01-08 10:32:34 -08:00
jax authors
25cf5cfee1
Merge pull request #19245 from jakevdp:complex-warning
...
PiperOrigin-RevId: 596633005
2024-01-08 10:23:32 -08:00
Jake VanderPlas
6ab9398574
Remove reference to np.ComplexWarning
...
This is not available in NumPy 2.0.
2024-01-08 09:53:06 -08:00
Jake VanderPlas
d673b9bf5c
[array api] add jax.numpy.permute_dims function
2024-01-08 09:30:51 -08:00
jax authors
754b74d542
Allow cumulative sums to be int32 in jax2tf
...
PiperOrigin-RevId: 596604690
2024-01-08 08:40:35 -08:00