6911 Commits

Author SHA1 Message Date
ilay menahem
390e90361a Add .hypothesis/ directory to .gitignore
and ppf and cdf to scipy.stats.uniform
2024-01-16 18:59:52 +00:00
jax authors
c0d51e7dde Merge pull request #19381 from jakevdp:fix-diff
PiperOrigin-RevId: 598884533
2024-01-16 10:35:05 -08:00
jax authors
94b2da6a3b Merge pull request #19302 from carlosgmartin:scipy-stats-sem
PiperOrigin-RevId: 598884144
2024-01-16 10:34:45 -08:00
Jake VanderPlas
17f5658db8 jnp.diff: support scalar prepend/append 2024-01-16 08:46:44 -08:00
Matthew Johnson
30c0fc4c5f [shard-map] add approx_top_k replication rule 2024-01-16 03:59:40 -08:00
George Necula
a1286d0021 [shape_poly] Improve core.max_dim and core.min_dim
Previously, we optimized `core.max_dim(a, b)` to `a`
if `a >= b` and to `b` if `a < b`. Now we also optimize
it to `b` if `a <= b`.

Similarly for `core.min_dim`.
At the same time we move more of the logic from `core.py`
to `shape_poly.py`.
2024-01-15 15:10:28 +02:00
Peter Hawkins
e558feaa5e Deprecate support for the mhlo dialect.
JAX has not used mhlo for some time, in favor of stablehlo. Deprecate support for this dialect in JAX's API and remove testing.

PiperOrigin-RevId: 598550225
2024-01-15 02:13:40 -08:00
Peter Hawkins
912a5ef771 Disable some tests that time out or OOM under ASAN.
PiperOrigin-RevId: 598543036
2024-01-15 01:40:44 -08:00
carlosgmartin
18ecd2e4fd Add scipy.stats.sem. 2024-01-13 22:17:21 -05:00
George Necula
0967a797e8 [shape_poly] Protect shape_poly: rename to _shape_poly.py.
The public APIs can be accessed through `jax.experimental.export`.

The shape_poly and serialization modules are still changing and I saw
external references to various symbols in them, even protected ones.
I have removed such references from the Google code base, and I want to take
another step to discourage direct access to its symbols.

PiperOrigin-RevId: 598119703
2024-01-13 02:05:11 -08:00
jax authors
b8b119d9b9 Cleanup deprecated compilation cache APIs.
Since the compilation cache is now initialized lazily,
existing APIs initialize_cache() and is_initialized()
are confusing. Deprecate these APIs.

Introduce a new API set_cache_dir() to explicitly set the
cache directory path in code.

Testing: revised unit tests, test workload.
PiperOrigin-RevId: 598073423
2024-01-12 22:44:48 -08:00
jax authors
1ab5a3579a Merge pull request #19348 from jakevdp:scipy-version
PiperOrigin-RevId: 597978339
2024-01-12 15:48:20 -08:00
Jake VanderPlas
1870eee062 Test: make scipy version parsing compatible with pre-releases 2024-01-12 14:35:28 -08:00
Jake VanderPlas
989618c5f7 [array api] add jax.numpy.concat 2024-01-12 13:12:09 -08:00
Peter Hawkins
d935213a0f [Pallas] Disable sanitizer builds of splash_attention_kernel_test.
These time out in CI.

PiperOrigin-RevId: 597845558
2024-01-12 08:29:05 -08:00
Peter Hawkins
2980e8f09c [JAX] Disable some CUDA tests that fail under ASAN, due to bugs in NCCL and Triton.
PiperOrigin-RevId: 597845384
2024-01-12 08:20:36 -08:00
George Necula
3b7917a56e [shape_poly] Improve and rename export.args_specs.
We rename it to `symbolic_args_specs` in line with the other
public APIs related to shape polymorphism. The function used to
be in _export.py for historical reasons, we now move it to
shape_poly.py but we export the `symbolci_args_specs` from
the public `jax.experimental.export`.

The improvement is that for the case when the `args` passed in
are TF arrays, we move the logic to extract the shapes and dtypes
from this function to the callers. This achieves a better
separation of the JAX and TF use cases.
2024-01-12 08:11:03 +02:00
jax authors
761cf8ba7d Merge pull request #19080 from Zantares:tenglu/fix_ut
PiperOrigin-RevId: 597723598
2024-01-11 20:26:31 -08:00
jax authors
8a8cd6d01a Merge pull request #19321 from jakevdp:diagonal
PiperOrigin-RevId: 597660818
2024-01-11 14:55:50 -08:00
Sharad Vikram
598b46aab5 [Pallas/TPU] Open source "Splash Attention" (Sparse Flash Attention), a general purpose attention kernel where you can specify an attention mask using NumPy.
PiperOrigin-RevId: 597658315
2024-01-11 14:47:39 -08:00
Jake VanderPlas
b08a010949 [array API] add jnp.linalg.diagonal 2024-01-11 12:52:15 -08:00
Peter Hawkins
35fc2ed8e0 Disable ASAN for several CUDA tests.
PiperOrigin-RevId: 597596726
2024-01-11 10:43:38 -08:00
Jake VanderPlas
1a39d8fdb2 [array API] implement jnp.pow; alias for jnp.power 2024-01-10 14:59:46 -08:00
Jake VanderPlas
4e55086dfb array api: add jnp.bitwise_* aliases 2024-01-10 14:22:20 -08:00
Jake VanderPlas
9890b23b0a Add jnp.vecdot 2024-01-10 13:11:37 -08:00
Eugene Zhulenev
ba4c2b1c7d [pjrt:cpu] Add CpuTopology to TfrtCpuClient and enable persistent compilation cache for cpu backend
PiperOrigin-RevId: 597327136
2024-01-10 12:40:57 -08:00
George Necula
b7f82e8cad [shape_poly] Improve the lexicographic ordering of symbolic expressions
In preparation for upcoming changes in the reasoning about
inequalities, we change the lexicographic ordering to
ensure that a symbolic expressions is strictly larger than
any constituent subexpressions. We add a `_size` attribute
that computes (and caches) the syntactic size of the expression.
2024-01-09 08:50:54 +02:00
jax authors
7ad7890068 Merge pull request #18953 from gnecula:poly_min_max
PiperOrigin-RevId: 596768330
2024-01-08 19:37:00 -08:00
jax authors
6a99e38a82 Merge pull request #19252 from jakevdp:fix-vecdot
PiperOrigin-RevId: 596762791
2024-01-08 18:54:07 -08:00
Jake VanderPlas
0ea2fc6c15 NumPy 2.0: skip complex atanh test 2024-01-08 18:04:36 -08:00
Jake VanderPlas
f901bead48 jnp.linalg.vecdot: fix broadcasting & conjugation semantics 2024-01-08 18:03:44 -08:00
jax authors
8f850f243c Merge pull request #19247 from jakevdp:new-ufuncs
PiperOrigin-RevId: 596742534
2024-01-08 17:11:31 -08:00
jax authors
856915f3c4 Merge pull request #19244 from jakevdp:permute-dims
PiperOrigin-RevId: 596741266
2024-01-08 17:02:46 -08:00
jax authors
da96633f11 Correct the cache miss metric instrumentation due to the new min cache entry size flag
Since introduction of the min cache entry size check for compilation cache, the cache miss metric overcounts the skipped caches whose sizes are smaller than the min cache entry size. After moving the metric instrumentation to compilation_cache.put_executable_and_time, the cache miss metric will be incremented if both compile time and cache entry size are greater than the minimum thresholds.

PiperOrigin-RevId: 596696013
2024-01-08 14:03:33 -08:00
Yash Katariya
dccc0e8e5c Preserve the specs passed by the user in the output sharding from a eager shard_map.
PiperOrigin-RevId: 596665787
2024-01-08 12:09:20 -08:00
Jake VanderPlas
bd4b01ebbd List several NumPy 2.0 ufuncs as unimplemented 2024-01-08 11:12:58 -08:00
George Necula
6b7b3a3902 [shape_poly] Replace non_negative_dim with max_dim and min_dim.
Previously, we had `core.non_negative_dim` and we used it to
express `max(d, 0)`. This is needed in several places internally
to express index computations involving clamping (for numpy
indexing), or striding and dilation (which have a conditional
semantics). It seemed that this special case was sufficient,
and we expressed `max(a, b)` as `a + non_negative(b - a)` and
`min(a, b)` as `a - non_negative(a - b)`.

One drawback was that `non_negative` can be a surprising
construct when it appears in error messages. Also, users need
`max` and `min` computations with dimensions. It is clearer if
we use `max` and `min` directly instead of rewriting these to
use `non_negative`. The drawback is that we now have to duplicate
some internal logic to for `max` and `min`, but overall I feel
this is worth it for the better error messages we get.
2024-01-08 20:54:18 +02:00
Jake VanderPlas
d673b9bf5c [array api] add jax.numpy.permute_dims function 2024-01-08 09:30:51 -08:00
George Necula
02acacd999 Fix unused import error, breaks ruff. 2024-01-08 17:55:38 +02:00
George Necula
69788d18b6 [export] Refactor the imports for the public API of jax.experimental.export
Previously we used `from jax.experimental.export import export` and
`export.export(fun)`. Now we want to add the public API directly
to `jax.experimental.export`, for the following desired usage:

```
from jax.experimental import export

exp: export.Exported = export.export(fun)
ser: bytearray = export.serialize(exp)
exp1: export.Exported = export.deserialized(ser)
export.call(exp1)
```

This change requires changing the type of
`jax.experimental.export.export` from a
module to a function. This confuses
pytype for the targets with strict type checking,
which is why I attempt to make this change
atomically throughout the internal code base.

In order to support backwards compatibility with
OSS packages, this change also includes explicit
JAX version checks in several OSS packages, and
also adds to the `export` function the attributes
that the old export module had.

PiperOrigin-RevId: 596563481
2024-01-08 05:29:56 -08:00
George Necula
ed2a839884 Move export backwards compatibility tests out of jax2tf. Step 3.
The last part of moving the tests: move jax2tf/tests/back_compat_test.py to tests/export_back_compat_test.py.

PiperOrigin-RevId: 596555577
2024-01-08 04:48:10 -08:00
George Necula
3195a069ef [shape_poly] Improved the tests for inequality comparisons.
Added more tests and broke some large tests into smaller ones.
2024-01-08 08:39:28 +02:00
George Necula
2b4177d35f [shape_poly] Improve error messages for shape polymorphism
Since all the inequality comparisons are done via the
`ge` method, until now all the error messages were about
greater or equal. Now we specify the actual original comparison.
2024-01-08 05:52:06 +02:00
jax authors
16699e4e78 Merge pull request #19224 from jakevdp:batched-solve
PiperOrigin-RevId: 596313955
2024-01-06 21:23:13 -08:00
George Necula
2ca7b31388 [shape_poly] Cache the hash values for symbolic expressions
Hashing is performance critical for symbolic expressions because
their internal representation is as dictionaries. In some of the
unit tests by caching the hash we save 20% of the calls. Hashing
will be even more important for the upcoming decision procedure
for symbolic expressions.

We also cache the sorting of the monomials in a symbolic expression.
2024-01-07 06:50:18 +02:00
George Necula
f1c87e0176 [shape_poly] Fixes for the lexicographic ordering of monomials.
There were several bugs in the ordering of atoms and
monomials. The ordering for atoms and moomials is used
for sorting, and the __eq__ is also used for hashing.

One bug was that the ordering of atoms sometimes used
the `id` ordering. Another (performance) bug was that
the __eq__ for atoms used the (semantic) __eq__ for
DimExpr. The latter is expensive to compute, but for
sorting all we need is a syntactic comparison.

We introduce a `_syntactic_cmp` method for atoms,
monomials and expressions and we use it exclusively
for the ordering of atoms and monomials.

We also clean up printing and add tests for ordering and
pretty printing. Now we print monomial in "decreasing" order.
This is a change from before, in the sense that "a + b" is
printed as "b + a".
2024-01-07 06:21:55 +02:00
Jake VanderPlas
20a1972d27 jnp.linalg.solve: fully implement batched cases 2024-01-05 15:21:49 -08:00
Yash Katariya
b8098b1782 Remove indices and devices from shard_arg_handlers and shard_args.
This only affects python dispatch path. This has no impact on the speed of cpp dispatch (which is why benchmarks are **not** regressing).

If your code ends up taking the python dispatch, then something is going wrong anyways.

PiperOrigin-RevId: 596081987
2024-01-05 14:17:14 -08:00
jax authors
6dd5a69a18 Merge pull request #19201 from jakevdp:sort-kwargs
PiperOrigin-RevId: 595818057
2024-01-04 15:34:15 -08:00
jax authors
ea66029731 Introduce min entry size check for compilation cache.
Currently, the persistent compilation cache has a time
threshold: the entry is cached only if the compilation
time is less than the threshold. If compilation happens
to take a while, but the resulting executable is small,
there is nothing that prevents caching. This can result
in a large number of small files in the cache.

Introduce a size threshold. If the resulting executable's
size (after serialization and compression) is less than
this threshold, don't cache. This check is in addition to
the compilation time check described above.

Testing: new unit test, test workload.
PiperOrigin-RevId: 595815611
2024-01-04 15:17:05 -08:00