18867 Commits

Author SHA1 Message Date
Jake VanderPlas
be8183d746 pallas: improve indexing trace time 2024-01-09 11:32:00 -08:00
jax authors
e76f514b49 Merge pull request #19270 from jakevdp:permute-dims-sig
PiperOrigin-RevId: 596992011
2024-01-09 11:20:20 -08:00
jax authors
2356e57af4 [Mosaic] Add lowering rule for jax.lax.logical_not
jax.lax.logical_not(x) is lowered into: xor x, 0xffffffff

PiperOrigin-RevId: 596965921
2024-01-09 09:58:43 -08:00
Jake VanderPlas
707657e5b7 Adjust permute_dims signature to match NumPy
This really doesn't matter because it's a position-only argument, but this
change satisfies our tests and is easier than making the tests smarter.
2024-01-09 09:56:19 -08:00
Sergei Lebedev
f219482212 The Triton MLIR bindings now include auto-generated wrappers for enums
PiperOrigin-RevId: 596873541
2024-01-09 03:00:47 -08:00
jax authors
df0f1e06e0 Merge pull request #19258 from gnecula:poly_new_order
PiperOrigin-RevId: 596818710
2024-01-08 23:00:43 -08:00
George Necula
b7f82e8cad [shape_poly] Improve the lexicographic ordering of symbolic expressions
In preparation for upcoming changes in the reasoning about
inequalities, we change the lexicographic ordering to
ensure that a symbolic expressions is strictly larger than
any constituent subexpressions. We add a `_size` attribute
that computes (and caches) the syntactic size of the expression.
2024-01-09 08:50:54 +02:00
jax authors
7ad7890068 Merge pull request #18953 from gnecula:poly_min_max
PiperOrigin-RevId: 596768330
2024-01-08 19:37:00 -08:00
jax authors
27f7035105 Update XLA dependency to use revision
5d539b8bd3.

PiperOrigin-RevId: 596767921
2024-01-08 19:28:48 -08:00
jax authors
6a99e38a82 Merge pull request #19252 from jakevdp:fix-vecdot
PiperOrigin-RevId: 596762791
2024-01-08 18:54:07 -08:00
jax authors
9556b09349 Merge pull request #19253 from jakevdp:fix-bitwise-count
PiperOrigin-RevId: 596756207
2024-01-08 18:20:13 -08:00
jax authors
06152e645a Merge pull request #19248 from jakevdp:atanh-test
PiperOrigin-RevId: 596755300
2024-01-08 18:11:20 -08:00
Jake VanderPlas
0ea2fc6c15 NumPy 2.0: skip complex atanh test 2024-01-08 18:04:36 -08:00
Jake VanderPlas
f901bead48 jnp.linalg.vecdot: fix broadcasting & conjugation semantics 2024-01-08 18:03:44 -08:00
jax authors
8f850f243c Merge pull request #19247 from jakevdp:new-ufuncs
PiperOrigin-RevId: 596742534
2024-01-08 17:11:31 -08:00
jax authors
856915f3c4 Merge pull request #19244 from jakevdp:permute-dims
PiperOrigin-RevId: 596741266
2024-01-08 17:02:46 -08:00
Mark Sandler
773e1499f1 Updates doc for host_local_array_to_global_array to reflect a few use case patterns and adds a few extra tests.
PiperOrigin-RevId: 596739035
2024-01-08 16:51:26 -08:00
Jake VanderPlas
60665f90ec jnp.bitwise_count: promote input to numeric 2024-01-08 15:19:54 -08:00
jax authors
da96633f11 Correct the cache miss metric instrumentation due to the new min cache entry size flag
Since introduction of the min cache entry size check for compilation cache, the cache miss metric overcounts the skipped caches whose sizes are smaller than the min cache entry size. After moving the metric instrumentation to compilation_cache.put_executable_and_time, the cache miss metric will be incremented if both compile time and cache entry size are greater than the minimum thresholds.

PiperOrigin-RevId: 596696013
2024-01-08 14:03:33 -08:00
Yash Katariya
dccc0e8e5c Preserve the specs passed by the user in the output sharding from a eager shard_map.
PiperOrigin-RevId: 596665787
2024-01-08 12:09:20 -08:00
Jake VanderPlas
bd4b01ebbd List several NumPy 2.0 ufuncs as unimplemented 2024-01-08 11:12:58 -08:00
George Necula
6b7b3a3902 [shape_poly] Replace non_negative_dim with max_dim and min_dim.
Previously, we had `core.non_negative_dim` and we used it to
express `max(d, 0)`. This is needed in several places internally
to express index computations involving clamping (for numpy
indexing), or striding and dilation (which have a conditional
semantics). It seemed that this special case was sufficient,
and we expressed `max(a, b)` as `a + non_negative(b - a)` and
`min(a, b)` as `a - non_negative(a - b)`.

One drawback was that `non_negative` can be a surprising
construct when it appears in error messages. Also, users need
`max` and `min` computations with dimensions. It is clearer if
we use `max` and `min` directly instead of rewriting these to
use `non_negative`. The drawback is that we now have to duplicate
some internal logic to for `max` and `min`, but overall I feel
this is worth it for the better error messages we get.
2024-01-08 20:54:18 +02:00
jax authors
f33f0e4337 Merge pull request #19220 from jakevdp:dep-kind
PiperOrigin-RevId: 596636214
2024-01-08 10:32:34 -08:00
jax authors
25cf5cfee1 Merge pull request #19245 from jakevdp:complex-warning
PiperOrigin-RevId: 596633005
2024-01-08 10:23:32 -08:00
Jake VanderPlas
6ab9398574 Remove reference to np.ComplexWarning
This is not available in NumPy 2.0.
2024-01-08 09:53:06 -08:00
Jake VanderPlas
d673b9bf5c [array api] add jax.numpy.permute_dims function 2024-01-08 09:30:51 -08:00
jax authors
754b74d542 Allow cumulative sums to be int32 in jax2tf
PiperOrigin-RevId: 596604690
2024-01-08 08:40:35 -08:00
Adam Paszke
2a60c184af Use grid_mapping instead of mosaic_grid_mapping to query scalar prefetch
mosaic_grid_mapping sometimes extend the scalar prefetch operands with arrays
describing the mesh, throwing off the math in the recent patch.

PiperOrigin-RevId: 596596640
2024-01-08 08:14:21 -08:00
jax authors
c5628d78e5 Merge pull request #19243 from gnecula:fix_ruff
PiperOrigin-RevId: 596595608
2024-01-08 08:05:15 -08:00
George Necula
02acacd999 Fix unused import error, breaks ruff. 2024-01-08 17:55:38 +02:00
George Necula
69788d18b6 [export] Refactor the imports for the public API of jax.experimental.export
Previously we used `from jax.experimental.export import export` and
`export.export(fun)`. Now we want to add the public API directly
to `jax.experimental.export`, for the following desired usage:

```
from jax.experimental import export

exp: export.Exported = export.export(fun)
ser: bytearray = export.serialize(exp)
exp1: export.Exported = export.deserialized(ser)
export.call(exp1)
```

This change requires changing the type of
`jax.experimental.export.export` from a
module to a function. This confuses
pytype for the targets with strict type checking,
which is why I attempt to make this change
atomically throughout the internal code base.

In order to support backwards compatibility with
OSS packages, this change also includes explicit
JAX version checks in several OSS packages, and
also adds to the `export` function the attributes
that the old export module had.

PiperOrigin-RevId: 596563481
2024-01-08 05:29:56 -08:00
George Necula
ed2a839884 Move export backwards compatibility tests out of jax2tf. Step 3.
The last part of moving the tests: move jax2tf/tests/back_compat_test.py to tests/export_back_compat_test.py.

PiperOrigin-RevId: 596555577
2024-01-08 04:48:10 -08:00
jax authors
3d608b1850 Update CUDA to 12.3 in JAX/TF/XLA CIs
This is updating CUDA to version 12.3. Related libraries (notably cuDNN) are also getting updated.

PiperOrigin-RevId: 596515360
2024-01-08 01:24:32 -08:00
jax authors
be2d995464 Merge pull request #19238 from gnecula:poly_conv
PiperOrigin-RevId: 596495971
2024-01-07 23:37:02 -08:00
George Necula
45fa5021db [shape_poly] Minor improvements for handling of dilation in convolutions.
Use core.dilate_dim, it is clearer and especially when using
shape polymoprhism it will result in better error messages.

Replace the use of np.maximum with core.max_dim, because the
latter will result in fewer errors in presence of shape polymorphism
(the core.max_dim is deferring the actual computation to shape
refinement time).
2024-01-08 08:56:53 +02:00
jax authors
141df477e1 Merge pull request #19236 from gnecula:poly_tests
PiperOrigin-RevId: 596489887
2024-01-07 22:56:33 -08:00
George Necula
3195a069ef [shape_poly] Improved the tests for inequality comparisons.
Added more tests and broke some large tests into smaller ones.
2024-01-08 08:39:28 +02:00
jax authors
51b71e2b39 Merge pull request #19230 from gnecula:poly_deprecated_msg
PiperOrigin-RevId: 596465851
2024-01-07 20:34:03 -08:00
Tomás Longeri
88542f0e56 [Mosaic] Run C++ passes from within custom_call_emitter.cc
PiperOrigin-RevId: 596464480
2024-01-07 20:25:45 -08:00
jax authors
020988c2b2 Merge pull request #19235 from gnecula:poly_clean_err
PiperOrigin-RevId: 596464100
2024-01-07 20:17:20 -08:00
George Necula
2b4177d35f [shape_poly] Improve error messages for shape polymorphism
Since all the inequality comparisons are done via the
`ge` method, until now all the error messages were about
greater or equal. Now we specify the actual original comparison.
2024-01-08 05:52:06 +02:00
jax authors
0776060a0b Update XLA dependency to use revision
b8acf469ed.

PiperOrigin-RevId: 596457575
2024-01-07 19:44:47 -08:00
jax authors
4998c80bcd Merge pull request #19231 from gnecula:poly_eq
PiperOrigin-RevId: 596400074
2024-01-07 10:24:36 -08:00
George Necula
cd0e10f29b [shape_poly] Simplify and speed-up the __eq__ functions for symbolic expressions
Equality is used heavily for symbolic expressions because we use them
as dictionary keys or in sets. Previously, we used a more complete
and more expensive form of equality where we attempted to prove that
"e1 - e2 >= 0" and "e1 - e2 <= 0". This is an overkill and none
of the tests we have so far rely on this power. Now we just
normalize "e1 - e2" and if it reduces syntactically to an integer
we check if the integer is 0. If the difference does not reduce
to an integer we say that the expressions are disequal.

This may possibly change user-visible behavior when it depends
on the outcome of equality comparisons of symbolic dimensions
in presence of shape polymorphism.
2024-01-07 13:18:18 +02:00
jax authors
16699e4e78 Merge pull request #19224 from jakevdp:batched-solve
PiperOrigin-RevId: 596313955
2024-01-06 21:23:13 -08:00
jax authors
d4f8a9bad5 Merge pull request #19221 from jakevdp:array-api-solve
PiperOrigin-RevId: 596312089
2024-01-06 21:13:33 -08:00
George Necula
cea77f5d17 Improve some deprecation error messages 2024-01-07 07:09:39 +02:00
jax authors
d5feeecdaa Merge pull request #19228 from gnecula:poly_hash_cache
PiperOrigin-RevId: 596311540
2024-01-06 21:04:15 -08:00
George Necula
2ca7b31388 [shape_poly] Cache the hash values for symbolic expressions
Hashing is performance critical for symbolic expressions because
their internal representation is as dictionaries. In some of the
unit tests by caching the hash we save 20% of the calls. Hashing
will be even more important for the upcoming decision procedure
for symbolic expressions.

We also cache the sorting of the monomials in a symbolic expression.
2024-01-07 06:50:18 +02:00
jax authors
50edcd57c0 Merge pull request #19227 from gnecula:poly_ordering
PiperOrigin-RevId: 596307071
2024-01-06 20:37:35 -08:00