jax authors
77de0df709
update change log's formatting.
...
PiperOrigin-RevId: 576214943
2023-10-24 11:52:24 -07:00
Peter Hawkins
8d49f9a159
Reverts 9b1a656c1ef1f93f9e93eccb662de9bebe66b51a
...
PiperOrigin-RevId: 576128882
2023-10-24 07:08:34 -07:00
jax authors
9b1a656c1e
Reverts 84c516974ab2b37169938d2d48a6c29a63c62c21
...
PiperOrigin-RevId: 575891656
2023-10-23 12:13:56 -07:00
jax authors
dde17cd5bc
Merge pull request #18180 from carlosgmartin:fill_diagonal
...
PiperOrigin-RevId: 575317151
2023-10-20 14:20:14 -07:00
carlosgmartin
3cb504c583
Add jax.numpy.fill_diagonal.
2023-10-20 16:47:46 -04:00
Yash Katariya
613369fc22
Finish 0.4.19 jax and jaxlib release
...
PiperOrigin-RevId: 574983871
2023-10-19 13:27:52 -07:00
Peter Hawkins
d856ecc6fb
Set RPATH, not RUNPATH in JAX CUDA builds.
...
Fixes https://github.com/google/jax/issues/17497
2023-10-12 09:38:10 -07:00
Jake VanderPlas
117f4bdf9b
Define jax.typing.DTypeLike
2023-10-10 08:46:36 -07:00
Peter Hawkins
4611d13c07
Only perform compilation cache writes from process 0.
...
This avoids problems with contending writes on filesystems such as GCS.
PiperOrigin-RevId: 572032482
2023-10-09 13:55:07 -07:00
Skye Wanderman-Milne
a06beaa1a2
Update versions post jax 0.4.18 release
2023-10-06 17:20:34 -07:00
Skye Wanderman-Milne
d4a1bb9292
Update setup.py and CHANGELOG for jax 0.4.18 release
2023-10-06 13:13:33 -07:00
jax authors
f0e4ea23cc
Merge pull request #17987 from jakevdp:lax-dep
...
PiperOrigin-RevId: 571401660
2023-10-06 12:23:20 -07:00
Peter Hawkins
8c4d020db9
Improve CUDA install documentation.
...
Mention NCCL as a dependency, since it will be required by the next jaxlib release.
Mention LD_LIBRARY_PATH and PATH as how one overrides the CUDA installation for local installs.
Fixes #17831
2023-10-06 14:36:29 -04:00
Jake VanderPlas
ce6a0c43ad
jax.lax: deprecate inadvertent exports & internal utilities
2023-10-06 11:26:03 -07:00
Peter Hawkins
efc18e4147
[JAX] Obtain NCCL via a stub, rather than linking it statically or dynamically.
...
This shrinks the CUDA jaxlib wheel size by around 80MB.
PiperOrigin-RevId: 570554454
2023-10-03 18:33:58 -07:00
Skye Wanderman-Milne
82b58386b7
Update versions and CHANGELOG after jax 0.4.17 release
2023-10-03 17:54:35 -07:00
Jake VanderPlas
a09fdf6e2f
Add jax.numpy.bitwise_count()
2023-10-03 13:48:16 -07:00
Jake VanderPlas
9247a62b2b
Add CHANGELOG entry for the jnp annotation change
2023-10-02 11:31:28 -07:00
Peter Hawkins
b7dfde8d87
Add notes about the new CUDA version restrictions to the changelog and installation instructions.
2023-09-27 15:56:47 -04:00
Peter Hawkins
a2e1f1f24e
Update changelog.
...
Bump the minimum CUDA 12 pip package versions to the current releases.
2023-09-26 18:21:51 -04:00
Peter Hawkins
2fd6df45e4
Fix test failures under SciPy 1.11 for scipy.stats.mode.
2023-09-23 20:15:51 +00:00
Jake VanderPlas
243a6a236c
dtypes.issubdtype: validate a when b is dtypes.extended
2023-09-21 15:53:05 -07:00
Jake VanderPlas
22818d664f
[random] deprecate named key creation functions
2023-09-21 13:57:49 -07:00
Ayaka
74bc42e53e
Fix typo in CHANGELOG.md
2023-09-21 14:37:19 +08:00
Jake VanderPlas
024b1f23d7
Remove deprecated submodule jax.abstract_arrays
2023-09-19 15:40:18 -07:00
Yash Katariya
dcc465b4de
Finish jax and jaxlib 0.4.16 release
...
PiperOrigin-RevId: 566477931
2023-09-18 19:09:19 -07:00
Yash Katariya
a2720ee2c3
Deprecate jax.experimental.pjit.with_sharding_constraint
. Replacement is jax.lax.with_sharding_constraint which has been available since 1 year.
...
PiperOrigin-RevId: 565389746
2023-09-14 09:23:03 -07:00
Roy Frostig
1f8cc44f4e
deprecate PRNGKeyArray.unsafe_raw_array
in favor of jax.random.key_data
...
The latter function is also better in that its behavior is invariant to `jit`,
whereas the `unsafe_raw_array` method only works in eager mode.
PiperOrigin-RevId: 565195381
2023-09-13 16:33:56 -07:00
Jake VanderPlas
4e6c1b68c7
Deprecate random.KeyArray and random.PRNGKeyArray
2023-09-13 14:05:42 -07:00
Jake VanderPlas
eeb32a7d1f
Finish deprecation cycle for abstract_arrays.ShapedArray & abstract_arrays.raise_to_shaped
...
PiperOrigin-RevId: 565142019
2023-09-13 13:21:46 -07:00
Jake VanderPlas
22ff7bd19a
Finish the deprecation cycle for jnp.alltrue, jnp.sometrue, jnp.product, jnp.cumproduct
...
These have been deprecated in JAX following similar deprecations in numpy v1.25.0
PiperOrigin-RevId: 565122288
2023-09-13 12:07:36 -07:00
Peter Hawkins
408c657436
Add a release note about a fixed Windows crash.
2023-09-07 09:35:25 -04:00
Jake VanderPlas
ca39457ea9
JEX: move jax.linear_util to jax.extend.linear_util
2023-08-30 18:32:12 -07:00
Jake VanderPlas
4b89d03147
Deprecate the contents of jax.prng
2023-08-30 15:13:32 -07:00
Skye Wanderman-Milne
f71ba0a2e7
Update versions and changelog post 0.4.15 release
2023-08-30 16:20:25 -04:00
Peter Hawkins
9be96c1d69
Deprecate a number of exports from jax.interpreters.xla.
...
Custom HLO lowering rules for primitives should be updated to use MLIR StableHLO lowering rules via jax.interpreter.mlir.
PiperOrigin-RevId: 561215967
2023-08-29 20:47:59 -07:00
Peter Hawkins
93900245aa
Remove jax.interpreters.xla.register_collective_primitive.
...
We aren't consuming this data any more. It existed only to compare against the set of multiprocess-allowed collectives, but we removed that list also. So this registry is completely pointless.
PiperOrigin-RevId: 561150259
2023-08-29 15:10:05 -07:00
Yash Katariya
6072d5993e
Any devices passed to jax.sharding.Mesh are required to be hashable.
...
This is true for mock devices or user specific devices and jax.devices() too.
Fix the tests so that the mock devices are hashable.
PiperOrigin-RevId: 561103167
2023-08-29 12:20:54 -07:00
Peter Hawkins
46ac9e2170
Use the default CSR matmul algorithm.
...
Previously we requested CUSPARSE_SPMM_CSR_ALG3 in an attempt to get deterministic results from cusparse SpMM CSR matmuls. In the past, Cusparse silently ignored this algorithm choice and used a different algorithm in cases where ALG3 was not supported, but cusparse 12.2.1 removed the silent fallback behavior. Since we're not actually getting deterministic behavior anyway in all cases, use the default algorithm always.
PiperOrigin-RevId: 560867049
2023-08-28 17:49:01 -07:00
Peter Hawkins
975dae34a4
Deprecate jax.numpy.trapz.
...
Expose the current implementation of jax.numpy.trapz as jax.scipy.integrate.trapezoid instead.
Fixes https://github.com/google/jax/issues/17244
2023-08-25 09:04:13 -06:00
Jake VanderPlas
665b176c2c
remove deprecated jax.lax.prod function
...
PiperOrigin-RevId: 559787522
2023-08-24 10:13:59 -07:00
Peter Hawkins
7c871916f7
Deprecate jax.numpy.in1d.
...
Issue https://github.com/google/jax/issues/17244
2023-08-23 17:36:14 -06:00
Jake VanderPlas
19a57e1a01
Deprecate jax.numpy.row_stack
2023-08-22 13:12:49 -07:00
Peter Hawkins
4224a4d129
Deprecate jax.scipy.linalg.tril and jax.scipy.linalg.triu.
...
The corresponding functions are deprecated in scipy. Use the equivalent jax.numpy functions instead.
2023-08-18 16:14:42 -04:00
George Necula
ad15a38ec1
[host_callback] Remove old backwards compatibility flag jax_host_callback_ad_transforms.
...
This flag was added in https://github.com/google/jax/pull/8678 in December 2021
when we changed the behavior of host_callback to not have special handling for autodiff. Nobody is using that flag now.
This is part of a longer project to replace uses of host_callback with jax.pure_callback and jax.experimental.io_callback.
PiperOrigin-RevId: 557520668
2023-08-16 10:01:49 -07:00
George Necula
cf4e1d414b
[jax2tf] Bump the default JAX serialization version to 7.
...
This enables shape assertion checking, the support for which
landed in XlaCallModule on July 12th, 2023.
See the CHANGELOG for details.
PiperOrigin-RevId: 556222908
2023-08-11 22:49:31 -07:00
Jake VanderPlas
ad8e719b82
Add jnp.ufunc and jnp.frompyfunc
2023-08-10 14:58:18 -07:00
Peter Hawkins
0e80d959c8
Mark jnp.{NINF,NZERO,PZERO} as deprecated.
...
This follows the upstream NumPy deprecation of these names (https://github.com/numpy/numpy/pull/24357 ).
PiperOrigin-RevId: 555548986
2023-08-10 10:25:21 -07:00
Skye Wanderman-Milne
3e50fea29e
Remove option to use StreamExecutor Cloud TPU client in JAX
...
It's been over three months since the new PJRT C API client was
enabled by default
(https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-8-march-29-2023 ).
PiperOrigin-RevId: 554935166
2023-08-08 14:05:27 -07:00
Jake Vanderplas
d8f799391b
COPYBARA_INTEGRATE_REVIEW= https://github.com/google/jax/pull/17027 from jakevdp:dtypes-annotations a116a9c498a7b085f9b3fec93b37da12289f6e31
...
PiperOrigin-RevId: 554905739
2023-08-08 20:38:44 +00:00