541 Commits

Author SHA1 Message Date
Yue Sheng
147c363ea6 Deprecate jax.clear_backends.
`jax.clear_backends` does not necessarily do what its name suggests and can lead to unexpected consequences, e.g., it will not destroy existing backends and release corresponding owned resources. Use `jax.clear_caches` if you only want to clean up compilation caches. For backward compatibilty or you really need to switch/reinitialize the default backend, use `jax.extend.backend.clear_backends`.

PiperOrigin-RevId: 616946337
2024-03-18 14:23:18 -07:00
Jake VanderPlas
154403c03d Finalize deprecations of jax.interpreters.ad config & source_info_util
These have been raising a DeprecationWarning since JAX 0.4.19, released 2023 Oct 19. I've left the undefined symbols in place for now, as they will raise an informative AttributeError.

PiperOrigin-RevId: 616931120
2024-03-18 13:33:17 -07:00
Peter Hawkins
ee2631e4da Remove --jax_parallel_functions_output_gda.
PiperOrigin-RevId: 616898032
2024-03-18 11:46:06 -07:00
rajasekharporeddy
e94299c946
Fix Typos in CHANGELOG.md
This PR fixes the typos in Change log documentation
2024-03-12 13:57:07 +05:30
Sergei Lebedev
930aaa5e47 Deprecated the jax.experimental.maps submodule
PiperOrigin-RevId: 614082251
2024-03-08 16:50:52 -08:00
Jake VanderPlas
c2d07a6623 Finalize deprecation of non-array arguments to array_equal/array_equiv 2024-02-29 05:31:37 -08:00
Jake VanderPlas
236275ebe1 Deprecate jax.tree_map for jax v0.4.26
Reverts f4045dceb206be1ea10ee651ccc6151809f2d9f3

PiperOrigin-RevId: 611230367
2024-02-28 14:29:01 -08:00
Yash Katariya
e0fd29082d Finish jax and jaxlib 0.4.25 release
PiperOrigin-RevId: 610413312
2024-02-26 08:19:05 -08:00
Yash Katariya
f4045dceb2 Remove the deprecation of jax.tree_map for the release of 0.4.25
PiperOrigin-RevId: 610014256
2024-02-24 09:30:06 -08:00
jax authors
be002b5f1c Merge pull request #19930 from jakevdp:dep-tree_map
PiperOrigin-RevId: 609508069
2024-02-22 15:01:35 -08:00
Jake VanderPlas
a5abe4568d Mention re-instated xla APIs in the CHANGELOG 2024-02-22 12:19:29 -08:00
Jake VanderPlas
e59a0506fe Deprecate jax.tree_map in favor of jax.tree.map 2024-02-22 11:35:39 -08:00
Sergei Lebedev
0bf8dddace Compile Triton kernels via XLA by default
PiperOrigin-RevId: 609299269
2024-02-22 02:32:26 -08:00
Peter Hawkins
aad02dba7e Increase minimum jaxlib version to 0.4.20.
jaxlib 0.4.20 has xla_extension_version 210 and mlir_api_version 54.

PiperOrigin-RevId: 609094229
2024-02-21 12:58:57 -08:00
Sergei Lebedev
57e59eb6c3 Removed deprecated jax.config methods and jax.config.config
Reverts dcc65e621ea3a68fdc79fa9f2c995743a7b3faf7

PiperOrigin-RevId: 608676645
2024-02-20 11:25:16 -08:00
Thomas Köppe
dcc65e621e Reverts b506fee9e389391efb1336bc7575dba913e75cdf
PiperOrigin-RevId: 608319964
2024-02-19 06:23:00 -08:00
Sergei Lebedev
b506fee9e3 Removed deprecated jax.config methods and jax.config.config
Reverts eb0343683547b6e2d29245f3ab6c91037c0cff81

PiperOrigin-RevId: 607803834
2024-02-19 06:21:15 -08:00
Jake VanderPlas
1fe46aa8be Error for deprecated scalar conversions of non-scalar arrays 2024-02-16 11:26:30 -08:00
Jake VanderPlas
6ffea0ba1f tree_transpose: optionally infer inner_treedef 2024-02-15 12:01:21 -08:00
Jake VanderPlas
6934a4b76b Add jax.tree module with aliases of jax.tree_util 2024-02-12 13:07:59 -08:00
Jake VanderPlas
bbfd4f2c26 jax.numpy: implement scalar boolean indexing 2024-02-09 11:00:00 -08:00
Jake VanderPlas
1b08d220f6 Fix jax 0.4.24 changelog 2024-02-09 09:51:41 -08:00
Yash Katariya
73e3dedf9d Update changelog and install doc to mention cuda 12.3 switch
PiperOrigin-RevId: 605473026
2024-02-08 17:21:56 -08:00
Pearu Peterson
82b2ae211c Add CUDA Array Interface consumer support 2024-02-07 12:08:36 +02:00
jax authors
136ab066b3 Merge pull request #19681 from skye:version
PiperOrigin-RevId: 604809179
2024-02-06 17:04:53 -08:00
Skye Wanderman-Milne
b93772fc24 Update version numbers post-0.4.24 release 2024-02-06 16:28:23 -08:00
Jake VanderPlas
35c0f64836 jnp.linalg.solve: deprecate batched 1D solves when b.ndim > 1 2024-02-06 11:37:42 -08:00
George Necula
fdf227e7b2 [export] Set default native serialization version to 9.
This version adds better support for JAX effects.

See description in CHANGELOG.md and also at
https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions.

PiperOrigin-RevId: 603579274
2024-02-01 21:56:03 -08:00
George Necula
af306680d0 [shape_poly] Forgot to update CHANGELOG for #19235. 2024-01-23 17:32:59 +01:00
Jake VanderPlas
9b9aa1efaf Finalize a number of deprecations from JAX 0.4.19
PiperOrigin-RevId: 600509530
2024-01-22 11:13:25 -08:00
Jake VanderPlas
91a33362de Deprecate jax.lax.tie_in 2024-01-18 13:13:47 -08:00
Jake VanderPlas
03ce8ca0ca jax.random: deprecate passing of batched keys to APIs 2024-01-17 12:53:24 -08:00
jax authors
aac996c4db Merge pull request #19390 from jakevdp:jnp-sign
PiperOrigin-RevId: 599203136
2024-01-17 09:48:17 -08:00
Peter Hawkins
c4368351d2 Add support for bool dlpack values.
PiperOrigin-RevId: 599199196
2024-01-17 09:30:42 -08:00
Jake VanderPlas
fb56224ae0 jnp.sign: use x/abs(x) for complex arguments 2024-01-17 08:59:40 -08:00
Jake VanderPlas
7d6a134f4e logsumexp: use NumPy 2.0 convention for complex sign 2024-01-16 16:15:06 -08:00
Jake VanderPlas
fa6d3f26ff jnp.unique: make return_inverse shape match NumPy 2.0 2024-01-16 11:47:45 -08:00
jax authors
94b2da6a3b Merge pull request #19302 from carlosgmartin:scipy-stats-sem
PiperOrigin-RevId: 598884144
2024-01-16 10:34:45 -08:00
Peter Hawkins
e558feaa5e Deprecate support for the mhlo dialect.
JAX has not used mhlo for some time, in favor of stablehlo. Deprecate support for this dialect in JAX's API and remove testing.

PiperOrigin-RevId: 598550225
2024-01-15 02:13:40 -08:00
carlosgmartin
18ecd2e4fd Add scipy.stats.sem. 2024-01-13 22:17:21 -05:00
George Necula
3b7917a56e [shape_poly] Improve and rename export.args_specs.
We rename it to `symbolic_args_specs` in line with the other
public APIs related to shape polymorphism. The function used to
be in _export.py for historical reasons, we now move it to
shape_poly.py but we export the `symbolci_args_specs` from
the public `jax.experimental.export`.

The improvement is that for the case when the `args` passed in
are TF arrays, we move the logic to extract the shapes and dtypes
from this function to the callers. This achieves a better
separation of the JAX and TF use cases.
2024-01-12 08:11:03 +02:00
George Necula
6cac99e664 [shape_poly] Deprecate shape_poly.PolySpec.
This class has limited usefulness, and it seems worth
removing it in favor of using strings for polymorphic
specifications, thus reducing the API surface.
2024-01-12 06:30:38 +02:00
Skye Wanderman-Milne
43918b7d87 [jax] Add pybind for PjRtExecutable::GetCostAnalysis/Executable.cost_analysis
This makes cross-compiled `Compiled.cost_analysis` work.

PiperOrigin-RevId: 597411014
2024-01-10 18:51:55 -08:00
George Necula
df280a11b0 [shape_poly] Introduce is_symbolic_dim and deprecate is_poly_dim.
The old is_poly_dim seems to be used in a few places externally.
This was from the time when the symbolic dimensions were polynomials,
now we use the more generic term symbolic dimension or expression.

We introduce is_symbolic_dim and we export it through the jax.experimental.export.
We plan to make the entire shape_poly.py module private, and this is
a necessary step.
2024-01-10 10:10:30 +02:00
George Necula
6b7b3a3902 [shape_poly] Replace non_negative_dim with max_dim and min_dim.
Previously, we had `core.non_negative_dim` and we used it to
express `max(d, 0)`. This is needed in several places internally
to express index computations involving clamping (for numpy
indexing), or striding and dilation (which have a conditional
semantics). It seemed that this special case was sufficient,
and we expressed `max(a, b)` as `a + non_negative(b - a)` and
`min(a, b)` as `a - non_negative(a - b)`.

One drawback was that `non_negative` can be a surprising
construct when it appears in error messages. Also, users need
`max` and `min` computations with dimensions. It is clearer if
we use `max` and `min` directly instead of rewriting these to
use `non_negative`. The drawback is that we now have to duplicate
some internal logic to for `max` and `min`, but overall I feel
this is worth it for the better error messages we get.
2024-01-08 20:54:18 +02:00
George Necula
69788d18b6 [export] Refactor the imports for the public API of jax.experimental.export
Previously we used `from jax.experimental.export import export` and
`export.export(fun)`. Now we want to add the public API directly
to `jax.experimental.export`, for the following desired usage:

```
from jax.experimental import export

exp: export.Exported = export.export(fun)
ser: bytearray = export.serialize(exp)
exp1: export.Exported = export.deserialized(ser)
export.call(exp1)
```

This change requires changing the type of
`jax.experimental.export.export` from a
module to a function. This confuses
pytype for the targets with strict type checking,
which is why I attempt to make this change
atomically throughout the internal code base.

In order to support backwards compatibility with
OSS packages, this change also includes explicit
JAX version checks in several OSS packages, and
also adds to the `export` function the attributes
that the old export module had.

PiperOrigin-RevId: 596563481
2024-01-08 05:29:56 -08:00
George Necula
3195a069ef [shape_poly] Improved the tests for inequality comparisons.
Added more tests and broke some large tests into smaller ones.
2024-01-08 08:39:28 +02:00
George Necula
cd0e10f29b [shape_poly] Simplify and speed-up the __eq__ functions for symbolic expressions
Equality is used heavily for symbolic expressions because we use them
as dictionary keys or in sets. Previously, we used a more complete
and more expensive form of equality where we attempted to prove that
"e1 - e2 >= 0" and "e1 - e2 <= 0". This is an overkill and none
of the tests we have so far rely on this power. Now we just
normalize "e1 - e2" and if it reduces syntactically to an integer
we check if the integer is 0. If the difference does not reduce
to an integer we say that the expressions are disequal.

This may possibly change user-visible behavior when it depends
on the outcome of equality comparisons of symbolic dimensions
in presence of shape polymorphism.
2024-01-07 13:18:18 +02:00
Jake VanderPlas
8b62516676 [array api] add stable & descending params to jnp.sort & jnp.argsort 2024-01-04 14:21:25 -08:00
Jake VanderPlas
47e5c81a2c jnp.ndarray.item(): add args support 2024-01-03 13:03:47 -08:00