518 Commits

Author SHA1 Message Date
Pearu Peterson
82b2ae211c Add CUDA Array Interface consumer support 2024-02-07 12:08:36 +02:00
jax authors
136ab066b3 Merge pull request #19681 from skye:version
PiperOrigin-RevId: 604809179
2024-02-06 17:04:53 -08:00
Skye Wanderman-Milne
b93772fc24 Update version numbers post-0.4.24 release 2024-02-06 16:28:23 -08:00
Jake VanderPlas
35c0f64836 jnp.linalg.solve: deprecate batched 1D solves when b.ndim > 1 2024-02-06 11:37:42 -08:00
George Necula
fdf227e7b2 [export] Set default native serialization version to 9.
This version adds better support for JAX effects.

See description in CHANGELOG.md and also at
https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions.

PiperOrigin-RevId: 603579274
2024-02-01 21:56:03 -08:00
George Necula
af306680d0 [shape_poly] Forgot to update CHANGELOG for #19235. 2024-01-23 17:32:59 +01:00
Jake VanderPlas
9b9aa1efaf Finalize a number of deprecations from JAX 0.4.19
PiperOrigin-RevId: 600509530
2024-01-22 11:13:25 -08:00
Jake VanderPlas
91a33362de Deprecate jax.lax.tie_in 2024-01-18 13:13:47 -08:00
Jake VanderPlas
03ce8ca0ca jax.random: deprecate passing of batched keys to APIs 2024-01-17 12:53:24 -08:00
jax authors
aac996c4db Merge pull request #19390 from jakevdp:jnp-sign
PiperOrigin-RevId: 599203136
2024-01-17 09:48:17 -08:00
Peter Hawkins
c4368351d2 Add support for bool dlpack values.
PiperOrigin-RevId: 599199196
2024-01-17 09:30:42 -08:00
Jake VanderPlas
fb56224ae0 jnp.sign: use x/abs(x) for complex arguments 2024-01-17 08:59:40 -08:00
Jake VanderPlas
7d6a134f4e logsumexp: use NumPy 2.0 convention for complex sign 2024-01-16 16:15:06 -08:00
Jake VanderPlas
fa6d3f26ff jnp.unique: make return_inverse shape match NumPy 2.0 2024-01-16 11:47:45 -08:00
jax authors
94b2da6a3b Merge pull request #19302 from carlosgmartin:scipy-stats-sem
PiperOrigin-RevId: 598884144
2024-01-16 10:34:45 -08:00
Peter Hawkins
e558feaa5e Deprecate support for the mhlo dialect.
JAX has not used mhlo for some time, in favor of stablehlo. Deprecate support for this dialect in JAX's API and remove testing.

PiperOrigin-RevId: 598550225
2024-01-15 02:13:40 -08:00
carlosgmartin
18ecd2e4fd Add scipy.stats.sem. 2024-01-13 22:17:21 -05:00
George Necula
3b7917a56e [shape_poly] Improve and rename export.args_specs.
We rename it to `symbolic_args_specs` in line with the other
public APIs related to shape polymorphism. The function used to
be in _export.py for historical reasons, we now move it to
shape_poly.py but we export the `symbolci_args_specs` from
the public `jax.experimental.export`.

The improvement is that for the case when the `args` passed in
are TF arrays, we move the logic to extract the shapes and dtypes
from this function to the callers. This achieves a better
separation of the JAX and TF use cases.
2024-01-12 08:11:03 +02:00
George Necula
6cac99e664 [shape_poly] Deprecate shape_poly.PolySpec.
This class has limited usefulness, and it seems worth
removing it in favor of using strings for polymorphic
specifications, thus reducing the API surface.
2024-01-12 06:30:38 +02:00
Skye Wanderman-Milne
43918b7d87 [jax] Add pybind for PjRtExecutable::GetCostAnalysis/Executable.cost_analysis
This makes cross-compiled `Compiled.cost_analysis` work.

PiperOrigin-RevId: 597411014
2024-01-10 18:51:55 -08:00
George Necula
df280a11b0 [shape_poly] Introduce is_symbolic_dim and deprecate is_poly_dim.
The old is_poly_dim seems to be used in a few places externally.
This was from the time when the symbolic dimensions were polynomials,
now we use the more generic term symbolic dimension or expression.

We introduce is_symbolic_dim and we export it through the jax.experimental.export.
We plan to make the entire shape_poly.py module private, and this is
a necessary step.
2024-01-10 10:10:30 +02:00
George Necula
6b7b3a3902 [shape_poly] Replace non_negative_dim with max_dim and min_dim.
Previously, we had `core.non_negative_dim` and we used it to
express `max(d, 0)`. This is needed in several places internally
to express index computations involving clamping (for numpy
indexing), or striding and dilation (which have a conditional
semantics). It seemed that this special case was sufficient,
and we expressed `max(a, b)` as `a + non_negative(b - a)` and
`min(a, b)` as `a - non_negative(a - b)`.

One drawback was that `non_negative` can be a surprising
construct when it appears in error messages. Also, users need
`max` and `min` computations with dimensions. It is clearer if
we use `max` and `min` directly instead of rewriting these to
use `non_negative`. The drawback is that we now have to duplicate
some internal logic to for `max` and `min`, but overall I feel
this is worth it for the better error messages we get.
2024-01-08 20:54:18 +02:00
George Necula
69788d18b6 [export] Refactor the imports for the public API of jax.experimental.export
Previously we used `from jax.experimental.export import export` and
`export.export(fun)`. Now we want to add the public API directly
to `jax.experimental.export`, for the following desired usage:

```
from jax.experimental import export

exp: export.Exported = export.export(fun)
ser: bytearray = export.serialize(exp)
exp1: export.Exported = export.deserialized(ser)
export.call(exp1)
```

This change requires changing the type of
`jax.experimental.export.export` from a
module to a function. This confuses
pytype for the targets with strict type checking,
which is why I attempt to make this change
atomically throughout the internal code base.

In order to support backwards compatibility with
OSS packages, this change also includes explicit
JAX version checks in several OSS packages, and
also adds to the `export` function the attributes
that the old export module had.

PiperOrigin-RevId: 596563481
2024-01-08 05:29:56 -08:00
George Necula
3195a069ef [shape_poly] Improved the tests for inequality comparisons.
Added more tests and broke some large tests into smaller ones.
2024-01-08 08:39:28 +02:00
George Necula
cd0e10f29b [shape_poly] Simplify and speed-up the __eq__ functions for symbolic expressions
Equality is used heavily for symbolic expressions because we use them
as dictionary keys or in sets. Previously, we used a more complete
and more expensive form of equality where we attempted to prove that
"e1 - e2 >= 0" and "e1 - e2 <= 0". This is an overkill and none
of the tests we have so far rely on this power. Now we just
normalize "e1 - e2" and if it reduces syntactically to an integer
we check if the integer is 0. If the difference does not reduce
to an integer we say that the expressions are disequal.

This may possibly change user-visible behavior when it depends
on the outcome of equality comparisons of symbolic dimensions
in presence of shape polymorphism.
2024-01-07 13:18:18 +02:00
Jake VanderPlas
8b62516676 [array api] add stable & descending params to jnp.sort & jnp.argsort 2024-01-04 14:21:25 -08:00
Jake VanderPlas
47e5c81a2c jnp.ndarray.item(): add args support 2024-01-03 13:03:47 -08:00
Jake VanderPlas
c06e186f60 Error on conversion of empty arrays to boolean.
PiperOrigin-RevId: 595264332
2024-01-02 19:26:45 -08:00
Jake VanderPlas
fff5ea579a Remove deprecated unsafe_raw_array method from PRNG keys
PiperOrigin-RevId: 595190146
2024-01-02 13:03:21 -08:00
Jake VanderPlas
cab63114b4 Remove deprecated function jax.numpy.trapz
This was deprecated prior to the JAX 0.4.16 release, so we have now met the 3 month deprecation period specified by the [API Compatiblity Policy](https://jax.readthedocs.io/en/latest/api_compatibility.html).

PiperOrigin-RevId: 592266215
2023-12-19 09:57:39 -08:00
Jake VanderPlas
e356d76913 Remove a number of deprecated APIs
All of these were deprecated prior to the JAX 0.4.16 release, on Sept 18 2023.
As of Monday Dec 18, we have met the 3 month deprecation period specified by the [API Compatiblity Policy](https://jax.readthedocs.io/en/latest/api_compatibility.html).

PiperOrigin-RevId: 591933493
2023-12-18 10:08:47 -08:00
Sergei Lebedev
41531123f4 Rolling back #18980, because it is not backwards compatible and breaks existing users.
Reverts 91faddd023c2df77df310f3f2f17eb2fa1e60df0

PiperOrigin-RevId: 591200403
2023-12-15 03:24:01 -08:00
George Necula
fd0f007765 [export] Refactor the imports for the public API of jax.experimental.export
Previously we used `from jax.experimental.export import export` and
`export.export(fun)`. Now we want to add the public API directly
to `jax.experimental.export`, for the following desired usage:

```
from jax.experimental import export

exp: export.Exported = export.export(fun)
ser: bytearray = export.serialize(exp)
exp1 = export.deserialized(ser)
export.call(exp1)
```

This change also includes a workaround to allow users to still
do `from jax.experimental.export import export`, for a while.
2023-12-15 10:00:05 +02:00
Yash Katariya
8bf3a86860 [roll forward 2] Remove the `jax_require_devices_during_lowering flag since it was temporary. Added the semi-breaking change to Changelog.md.
Reverts b52bcc1639368069075284eefc763f824ca155f1

PiperOrigin-RevId: 590959383
2023-12-14 09:14:25 -08:00
Yash Katariya
6e1ab7ca3f Finish release of jax and jaxlib 0.4.23
PiperOrigin-RevId: 590833947
2023-12-13 23:39:08 -08:00
Peter Hawkins
b392622647 Add patch to suppress XLA:GPU logging.
PiperOrigin-RevId: 590780227
2023-12-13 18:53:50 -08:00
Yash Katariya
25c16c0b78 Finish jax and jaxlib 0.4.22 release
PiperOrigin-RevId: 590775311
2023-12-13 18:26:47 -08:00
Yash Katariya
b52bcc1639 Reverts 3c07c10a9a55f9a32390dd10cf3f420bdf3f1ed8
PiperOrigin-RevId: 590700623
2023-12-13 13:45:14 -08:00
Yash Katariya
3c07c10a9a Remove the `jax_require_devices_during_lowering flag since it was temporary. Added the semi-breaking change to Changelog.md.
PiperOrigin-RevId: 590684939
2023-12-13 12:48:48 -08:00
Jake VanderPlas
35b84402c0 Deprecate arr.device_buffer and arr.device_buffers 2023-12-06 10:20:29 -08:00
Yash Katariya
a9bfbd32e1 Finish jax and jaxlib 0.4.21 release
PiperOrigin-RevId: 587866580
2023-12-04 15:51:58 -08:00
Yash Katariya
f0bc7e0fc6 Reverts f0382a5838f4526d21631e804f6fe576bfc3f97e
PiperOrigin-RevId: 587231484
2023-12-01 22:06:33 -08:00
jax authors
8ad774fb10 Automate arguments for jax.distributed.initialize for cloud TPU environments.
PiperOrigin-RevId: 586892544
2023-11-30 22:25:00 -08:00
Jake VanderPlas
97beb01c43 Deprecate the device() method of JAX arrays 2023-11-30 11:43:02 -08:00
Jake VanderPlas
0aec40a16f Deprecate arr.device_buffer and arr.device_buffers 2023-11-29 15:31:01 -08:00
Jake VanderPlas
13dd5e42cc Deprecate non-array inputs to jnp.array_equal & jnp.array_equiv 2023-11-28 13:55:18 -08:00
Peter Hawkins
84c1e825c0 Make jax.numpy.where()'s condition, x, y arguments positional-only to match numpy.where.
PiperOrigin-RevId: 584377134
2023-11-21 11:10:12 -08:00
Peter Hawkins
49c80e68d1 Fix error/hang when non-finite values are passed to non-symmetric Eigendecomposition.
Improve the documentation of lax.eig().

Fixes https://github.com/google/jax/issues/18226

PiperOrigin-RevId: 584170564
2023-11-20 17:32:16 -08:00
jax authors
dce6ab4548 Reverts 2aaa7559f96e4bb7b0271665bf386bf3ba22c451
PiperOrigin-RevId: 584033001
2023-11-20 08:23:40 -08:00
Peter Hawkins
30a0136813 Increase minimum jaxlib version to 0.4.19.
0.4.19 has xla_extension version 207 and mlir_api_version 54.

PiperOrigin-RevId: 583412447
2023-11-17 09:38:31 -08:00