439 Commits

Author SHA1 Message Date
Peter Hawkins
b7dfde8d87 Add notes about the new CUDA version restrictions to the changelog and installation instructions. 2023-09-27 15:56:47 -04:00
Peter Hawkins
a2e1f1f24e Update changelog.
Bump the minimum CUDA 12 pip package versions to the current releases.
2023-09-26 18:21:51 -04:00
Peter Hawkins
2fd6df45e4 Fix test failures under SciPy 1.11 for scipy.stats.mode. 2023-09-23 20:15:51 +00:00
Jake VanderPlas
243a6a236c dtypes.issubdtype: validate a when b is dtypes.extended 2023-09-21 15:53:05 -07:00
Jake VanderPlas
22818d664f [random] deprecate named key creation functions 2023-09-21 13:57:49 -07:00
Ayaka
74bc42e53e
Fix typo in CHANGELOG.md 2023-09-21 14:37:19 +08:00
Jake VanderPlas
024b1f23d7 Remove deprecated submodule jax.abstract_arrays 2023-09-19 15:40:18 -07:00
Yash Katariya
dcc465b4de Finish jax and jaxlib 0.4.16 release
PiperOrigin-RevId: 566477931
2023-09-18 19:09:19 -07:00
Yash Katariya
a2720ee2c3 Deprecate jax.experimental.pjit.with_sharding_constraint. Replacement is jax.lax.with_sharding_constraint which has been available since 1 year.
PiperOrigin-RevId: 565389746
2023-09-14 09:23:03 -07:00
Roy Frostig
1f8cc44f4e deprecate PRNGKeyArray.unsafe_raw_array in favor of jax.random.key_data
The latter function is also better in that its behavior is invariant to `jit`,
whereas the `unsafe_raw_array` method only works in eager mode.

PiperOrigin-RevId: 565195381
2023-09-13 16:33:56 -07:00
Jake VanderPlas
4e6c1b68c7 Deprecate random.KeyArray and random.PRNGKeyArray 2023-09-13 14:05:42 -07:00
Jake VanderPlas
eeb32a7d1f Finish deprecation cycle for abstract_arrays.ShapedArray & abstract_arrays.raise_to_shaped
PiperOrigin-RevId: 565142019
2023-09-13 13:21:46 -07:00
Jake VanderPlas
22ff7bd19a Finish the deprecation cycle for jnp.alltrue, jnp.sometrue, jnp.product, jnp.cumproduct
These have been deprecated in JAX following similar deprecations in numpy v1.25.0

PiperOrigin-RevId: 565122288
2023-09-13 12:07:36 -07:00
Peter Hawkins
408c657436 Add a release note about a fixed Windows crash. 2023-09-07 09:35:25 -04:00
Jake VanderPlas
ca39457ea9 JEX: move jax.linear_util to jax.extend.linear_util 2023-08-30 18:32:12 -07:00
Jake VanderPlas
4b89d03147 Deprecate the contents of jax.prng 2023-08-30 15:13:32 -07:00
Skye Wanderman-Milne
f71ba0a2e7 Update versions and changelog post 0.4.15 release 2023-08-30 16:20:25 -04:00
Peter Hawkins
9be96c1d69 Deprecate a number of exports from jax.interpreters.xla.
Custom HLO lowering rules for primitives should be updated to use MLIR StableHLO lowering rules via jax.interpreter.mlir.

PiperOrigin-RevId: 561215967
2023-08-29 20:47:59 -07:00
Peter Hawkins
93900245aa Remove jax.interpreters.xla.register_collective_primitive.
We aren't consuming this data any more. It existed only to compare against the set of multiprocess-allowed collectives, but we removed that list also. So this registry is completely pointless.

PiperOrigin-RevId: 561150259
2023-08-29 15:10:05 -07:00
Yash Katariya
6072d5993e Any devices passed to jax.sharding.Mesh are required to be hashable.
This is true for mock devices or user specific devices and jax.devices() too.

Fix the tests so that the mock devices are hashable.

PiperOrigin-RevId: 561103167
2023-08-29 12:20:54 -07:00
Peter Hawkins
46ac9e2170 Use the default CSR matmul algorithm.
Previously we requested CUSPARSE_SPMM_CSR_ALG3 in an attempt to get deterministic results from cusparse SpMM CSR matmuls. In the past, Cusparse silently ignored this algorithm choice and used a different algorithm in cases where ALG3 was not supported, but cusparse 12.2.1 removed the silent fallback behavior. Since we're not actually getting deterministic behavior anyway in all cases, use the default algorithm always.

PiperOrigin-RevId: 560867049
2023-08-28 17:49:01 -07:00
Peter Hawkins
975dae34a4 Deprecate jax.numpy.trapz.
Expose the current implementation of jax.numpy.trapz as jax.scipy.integrate.trapezoid instead.

Fixes https://github.com/google/jax/issues/17244
2023-08-25 09:04:13 -06:00
Jake VanderPlas
665b176c2c remove deprecated jax.lax.prod function
PiperOrigin-RevId: 559787522
2023-08-24 10:13:59 -07:00
Peter Hawkins
7c871916f7 Deprecate jax.numpy.in1d.
Issue https://github.com/google/jax/issues/17244
2023-08-23 17:36:14 -06:00
Jake VanderPlas
19a57e1a01 Deprecate jax.numpy.row_stack 2023-08-22 13:12:49 -07:00
Peter Hawkins
4224a4d129 Deprecate jax.scipy.linalg.tril and jax.scipy.linalg.triu.
The corresponding functions are deprecated in scipy. Use the equivalent jax.numpy functions instead.
2023-08-18 16:14:42 -04:00
George Necula
ad15a38ec1 [host_callback] Remove old backwards compatibility flag jax_host_callback_ad_transforms.
This flag was added in https://github.com/google/jax/pull/8678 in December 2021
when we changed the behavior of host_callback to not have special handling for autodiff. Nobody is using that flag now.

This is part of a longer project to replace uses of host_callback with jax.pure_callback and jax.experimental.io_callback.

PiperOrigin-RevId: 557520668
2023-08-16 10:01:49 -07:00
George Necula
cf4e1d414b [jax2tf] Bump the default JAX serialization version to 7.
This enables shape assertion checking, the support for which
landed in XlaCallModule on July 12th, 2023.

See the CHANGELOG for details.

PiperOrigin-RevId: 556222908
2023-08-11 22:49:31 -07:00
Jake VanderPlas
ad8e719b82 Add jnp.ufunc and jnp.frompyfunc 2023-08-10 14:58:18 -07:00
Peter Hawkins
0e80d959c8 Mark jnp.{NINF,NZERO,PZERO} as deprecated.
This follows the upstream NumPy deprecation of these names (https://github.com/numpy/numpy/pull/24357).

PiperOrigin-RevId: 555548986
2023-08-10 10:25:21 -07:00
Skye Wanderman-Milne
3e50fea29e Remove option to use StreamExecutor Cloud TPU client in JAX
It's been over three months since the new PJRT C API client was
enabled by default
(https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-8-march-29-2023).

PiperOrigin-RevId: 554935166
2023-08-08 14:05:27 -07:00
Jake Vanderplas
d8f799391b COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/17027 from jakevdp:dtypes-annotations a116a9c498a7b085f9b3fec93b37da12289f6e31
PiperOrigin-RevId: 554905739
2023-08-08 20:38:44 +00:00
Peter Hawkins
afd56c15d9 Move jax.jaxpr_util to jax._src.jaxpr_util, and split it into a separate build target.
Change jaxpr_util_test to be a py_test(), since there's no point testing it on every hardware configuration.

PiperOrigin-RevId: 554861284
2023-08-08 10:09:09 -07:00
Peter Hawkins
c879f65aa6 [JAX] Remove the non-coordination service distributed service implementation from JAX.
The coordination service has been the default for a long time, and has significant additional functionality. Remove the older code path to simplify the code.

PiperOrigin-RevId: 554608165
2023-08-07 15:17:25 -07:00
George Necula
8d80e2587b [jax2tf] Turn on JAX native serialization by default.
See changes to the README.md for mechanisms to override the default.

PiperOrigin-RevId: 554390866
2023-08-07 01:03:55 -07:00
Patrick Kidger
d6dad3827d Documented the shortening of tracebacks 2023-08-04 12:20:19 -07:00
Skye Wanderman-Milne
011fc88c03 Update versions and changelog for jax 0.4.14 release 2023-07-27 16:22:53 -07:00
jax authors
1ceddfc98a Merge pull request #16710 from gnecula:poly_max0
PiperOrigin-RevId: 549515427
2023-07-19 21:40:17 -07:00
Jake VanderPlas
7205160095 Re-parameterize jax.random.gamma for better behavior at endpoints 2023-07-19 16:15:03 -07:00
jax authors
5ae3ac28cd Add deprecation of jax.stages.Compiled.compiler_ir to the change log
PiperOrigin-RevId: 549415191
2023-07-19 13:48:55 -07:00
George Necula
e643f98558 [shape_poly] Reimplement the shape constraint checking using shape assertions.
Most of the functionality is for the JAX native serialization case.
This relies on newly added functionality to xla_extension.refine_polymorphic_shapes
that handles custom calls @static_assertion.

As a beneficial side-effect now we get shape constraint checking for jax2tf
graph serialization when the resulting function is executed in graph mode.
2023-07-19 09:56:33 +03:00
Peter Hawkins
59509dc2b3 Remove the jax_array config option, which does nothing.
PiperOrigin-RevId: 548981491
2023-07-18 06:16:06 -07:00
Yash Katariya
f0ce0d8c6a Delete in_axis_resources and out_axis_resources from pjit since it's been more than 3 months since their deprecation. The replace is to use in_shardings and out_shardings. You can still pass PartitionSpecs to {in|out}_shardings to pjit.
PiperOrigin-RevId: 548673905
2023-07-17 06:35:49 -07:00
George Necula
603eeb1901 Copybara import of the project:
--
06bf5fe7b2ac97156df541bab989dc5beb1aff0c by George Necula <gcnecula@gmail.com>:

[jax2tf] Added a flag and environment variable to control the serialization version.

This allows us to control the serialization version to be compatible with
the deployed version of tf.XlaCallModule. In particular, we can run
most tests with the maximum available version, while keeping the
default lower.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16746 from gnecula:tf_version 06bf5fe7b2ac97156df541bab989dc5beb1aff0c
PiperOrigin-RevId: 548504243
2023-07-16 09:27:12 -07:00
Peter Hawkins
651f87733b Remove jax_jit_pjit_api_merge.
PiperOrigin-RevId: 548236671
2023-07-14 15:25:00 -07:00
Yash Katariya
89c78bf53f jax.jit now works correctly if both donate_argnums and donate_argnames are specified.
Update the docstring and changelog too to mention `donate_argnames`.

PiperOrigin-RevId: 548223395
2023-07-14 14:28:16 -07:00
Jake VanderPlas
21f6736005 Remove several deprecated APIs 2023-07-11 12:42:32 -07:00
Jake VanderPlas
b581ad1f33 Remove several deprecated jax.Array methods:
- arr.broadcast
- arr.broadcast_in_dim
- arr.split

These have been deprecated since JAX v0.4.5

PiperOrigin-RevId: 547228974
2023-07-11 10:34:27 -07:00
Jake VanderPlas
9962065deb Require ml_dtypes>=0.2 2023-07-07 12:07:44 -07:00
Jake VanderPlas
d0e75ca117 Require index update optional arguments to be passed by keyword.
Passing these keywords by position has been deprecated and has raised a warning since JAX v0.4.7 (Released 27 March 2023)

PiperOrigin-RevId: 544620172
2023-06-30 04:30:34 -07:00