451 Commits

Author SHA1 Message Date
Peter Hawkins
d856ecc6fb Set RPATH, not RUNPATH in JAX CUDA builds.
Fixes https://github.com/google/jax/issues/17497
2023-10-12 09:38:10 -07:00
Jake VanderPlas
117f4bdf9b Define jax.typing.DTypeLike 2023-10-10 08:46:36 -07:00
Peter Hawkins
4611d13c07 Only perform compilation cache writes from process 0.
This avoids problems with contending writes on filesystems such as GCS.

PiperOrigin-RevId: 572032482
2023-10-09 13:55:07 -07:00
Skye Wanderman-Milne
a06beaa1a2 Update versions post jax 0.4.18 release 2023-10-06 17:20:34 -07:00
Skye Wanderman-Milne
d4a1bb9292 Update setup.py and CHANGELOG for jax 0.4.18 release 2023-10-06 13:13:33 -07:00
jax authors
f0e4ea23cc Merge pull request #17987 from jakevdp:lax-dep
PiperOrigin-RevId: 571401660
2023-10-06 12:23:20 -07:00
Peter Hawkins
8c4d020db9 Improve CUDA install documentation.
Mention NCCL as a dependency, since it will be required by the next jaxlib release.
Mention LD_LIBRARY_PATH and PATH as how one overrides the CUDA installation for local installs.

Fixes #17831
2023-10-06 14:36:29 -04:00
Jake VanderPlas
ce6a0c43ad jax.lax: deprecate inadvertent exports & internal utilities 2023-10-06 11:26:03 -07:00
Peter Hawkins
efc18e4147 [JAX] Obtain NCCL via a stub, rather than linking it statically or dynamically.
This shrinks the CUDA jaxlib wheel size by around 80MB.

PiperOrigin-RevId: 570554454
2023-10-03 18:33:58 -07:00
Skye Wanderman-Milne
82b58386b7 Update versions and CHANGELOG after jax 0.4.17 release 2023-10-03 17:54:35 -07:00
Jake VanderPlas
a09fdf6e2f Add jax.numpy.bitwise_count() 2023-10-03 13:48:16 -07:00
Jake VanderPlas
9247a62b2b Add CHANGELOG entry for the jnp annotation change 2023-10-02 11:31:28 -07:00
Peter Hawkins
b7dfde8d87 Add notes about the new CUDA version restrictions to the changelog and installation instructions. 2023-09-27 15:56:47 -04:00
Peter Hawkins
a2e1f1f24e Update changelog.
Bump the minimum CUDA 12 pip package versions to the current releases.
2023-09-26 18:21:51 -04:00
Peter Hawkins
2fd6df45e4 Fix test failures under SciPy 1.11 for scipy.stats.mode. 2023-09-23 20:15:51 +00:00
Jake VanderPlas
243a6a236c dtypes.issubdtype: validate a when b is dtypes.extended 2023-09-21 15:53:05 -07:00
Jake VanderPlas
22818d664f [random] deprecate named key creation functions 2023-09-21 13:57:49 -07:00
Ayaka
74bc42e53e
Fix typo in CHANGELOG.md 2023-09-21 14:37:19 +08:00
Jake VanderPlas
024b1f23d7 Remove deprecated submodule jax.abstract_arrays 2023-09-19 15:40:18 -07:00
Yash Katariya
dcc465b4de Finish jax and jaxlib 0.4.16 release
PiperOrigin-RevId: 566477931
2023-09-18 19:09:19 -07:00
Yash Katariya
a2720ee2c3 Deprecate jax.experimental.pjit.with_sharding_constraint. Replacement is jax.lax.with_sharding_constraint which has been available since 1 year.
PiperOrigin-RevId: 565389746
2023-09-14 09:23:03 -07:00
Roy Frostig
1f8cc44f4e deprecate PRNGKeyArray.unsafe_raw_array in favor of jax.random.key_data
The latter function is also better in that its behavior is invariant to `jit`,
whereas the `unsafe_raw_array` method only works in eager mode.

PiperOrigin-RevId: 565195381
2023-09-13 16:33:56 -07:00
Jake VanderPlas
4e6c1b68c7 Deprecate random.KeyArray and random.PRNGKeyArray 2023-09-13 14:05:42 -07:00
Jake VanderPlas
eeb32a7d1f Finish deprecation cycle for abstract_arrays.ShapedArray & abstract_arrays.raise_to_shaped
PiperOrigin-RevId: 565142019
2023-09-13 13:21:46 -07:00
Jake VanderPlas
22ff7bd19a Finish the deprecation cycle for jnp.alltrue, jnp.sometrue, jnp.product, jnp.cumproduct
These have been deprecated in JAX following similar deprecations in numpy v1.25.0

PiperOrigin-RevId: 565122288
2023-09-13 12:07:36 -07:00
Peter Hawkins
408c657436 Add a release note about a fixed Windows crash. 2023-09-07 09:35:25 -04:00
Jake VanderPlas
ca39457ea9 JEX: move jax.linear_util to jax.extend.linear_util 2023-08-30 18:32:12 -07:00
Jake VanderPlas
4b89d03147 Deprecate the contents of jax.prng 2023-08-30 15:13:32 -07:00
Skye Wanderman-Milne
f71ba0a2e7 Update versions and changelog post 0.4.15 release 2023-08-30 16:20:25 -04:00
Peter Hawkins
9be96c1d69 Deprecate a number of exports from jax.interpreters.xla.
Custom HLO lowering rules for primitives should be updated to use MLIR StableHLO lowering rules via jax.interpreter.mlir.

PiperOrigin-RevId: 561215967
2023-08-29 20:47:59 -07:00
Peter Hawkins
93900245aa Remove jax.interpreters.xla.register_collective_primitive.
We aren't consuming this data any more. It existed only to compare against the set of multiprocess-allowed collectives, but we removed that list also. So this registry is completely pointless.

PiperOrigin-RevId: 561150259
2023-08-29 15:10:05 -07:00
Yash Katariya
6072d5993e Any devices passed to jax.sharding.Mesh are required to be hashable.
This is true for mock devices or user specific devices and jax.devices() too.

Fix the tests so that the mock devices are hashable.

PiperOrigin-RevId: 561103167
2023-08-29 12:20:54 -07:00
Peter Hawkins
46ac9e2170 Use the default CSR matmul algorithm.
Previously we requested CUSPARSE_SPMM_CSR_ALG3 in an attempt to get deterministic results from cusparse SpMM CSR matmuls. In the past, Cusparse silently ignored this algorithm choice and used a different algorithm in cases where ALG3 was not supported, but cusparse 12.2.1 removed the silent fallback behavior. Since we're not actually getting deterministic behavior anyway in all cases, use the default algorithm always.

PiperOrigin-RevId: 560867049
2023-08-28 17:49:01 -07:00
Peter Hawkins
975dae34a4 Deprecate jax.numpy.trapz.
Expose the current implementation of jax.numpy.trapz as jax.scipy.integrate.trapezoid instead.

Fixes https://github.com/google/jax/issues/17244
2023-08-25 09:04:13 -06:00
Jake VanderPlas
665b176c2c remove deprecated jax.lax.prod function
PiperOrigin-RevId: 559787522
2023-08-24 10:13:59 -07:00
Peter Hawkins
7c871916f7 Deprecate jax.numpy.in1d.
Issue https://github.com/google/jax/issues/17244
2023-08-23 17:36:14 -06:00
Jake VanderPlas
19a57e1a01 Deprecate jax.numpy.row_stack 2023-08-22 13:12:49 -07:00
Peter Hawkins
4224a4d129 Deprecate jax.scipy.linalg.tril and jax.scipy.linalg.triu.
The corresponding functions are deprecated in scipy. Use the equivalent jax.numpy functions instead.
2023-08-18 16:14:42 -04:00
George Necula
ad15a38ec1 [host_callback] Remove old backwards compatibility flag jax_host_callback_ad_transforms.
This flag was added in https://github.com/google/jax/pull/8678 in December 2021
when we changed the behavior of host_callback to not have special handling for autodiff. Nobody is using that flag now.

This is part of a longer project to replace uses of host_callback with jax.pure_callback and jax.experimental.io_callback.

PiperOrigin-RevId: 557520668
2023-08-16 10:01:49 -07:00
George Necula
cf4e1d414b [jax2tf] Bump the default JAX serialization version to 7.
This enables shape assertion checking, the support for which
landed in XlaCallModule on July 12th, 2023.

See the CHANGELOG for details.

PiperOrigin-RevId: 556222908
2023-08-11 22:49:31 -07:00
Jake VanderPlas
ad8e719b82 Add jnp.ufunc and jnp.frompyfunc 2023-08-10 14:58:18 -07:00
Peter Hawkins
0e80d959c8 Mark jnp.{NINF,NZERO,PZERO} as deprecated.
This follows the upstream NumPy deprecation of these names (https://github.com/numpy/numpy/pull/24357).

PiperOrigin-RevId: 555548986
2023-08-10 10:25:21 -07:00
Skye Wanderman-Milne
3e50fea29e Remove option to use StreamExecutor Cloud TPU client in JAX
It's been over three months since the new PJRT C API client was
enabled by default
(https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-8-march-29-2023).

PiperOrigin-RevId: 554935166
2023-08-08 14:05:27 -07:00
Jake Vanderplas
d8f799391b COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/17027 from jakevdp:dtypes-annotations a116a9c498a7b085f9b3fec93b37da12289f6e31
PiperOrigin-RevId: 554905739
2023-08-08 20:38:44 +00:00
Peter Hawkins
afd56c15d9 Move jax.jaxpr_util to jax._src.jaxpr_util, and split it into a separate build target.
Change jaxpr_util_test to be a py_test(), since there's no point testing it on every hardware configuration.

PiperOrigin-RevId: 554861284
2023-08-08 10:09:09 -07:00
Peter Hawkins
c879f65aa6 [JAX] Remove the non-coordination service distributed service implementation from JAX.
The coordination service has been the default for a long time, and has significant additional functionality. Remove the older code path to simplify the code.

PiperOrigin-RevId: 554608165
2023-08-07 15:17:25 -07:00
George Necula
8d80e2587b [jax2tf] Turn on JAX native serialization by default.
See changes to the README.md for mechanisms to override the default.

PiperOrigin-RevId: 554390866
2023-08-07 01:03:55 -07:00
Patrick Kidger
d6dad3827d Documented the shortening of tracebacks 2023-08-04 12:20:19 -07:00
Skye Wanderman-Milne
011fc88c03 Update versions and changelog for jax 0.4.14 release 2023-07-27 16:22:53 -07:00
jax authors
1ceddfc98a Merge pull request #16710 from gnecula:poly_max0
PiperOrigin-RevId: 549515427
2023-07-19 21:40:17 -07:00