Ayaka
74bc42e53e
Fix typo in CHANGELOG.md
2023-09-21 14:37:19 +08:00
Jake VanderPlas
024b1f23d7
Remove deprecated submodule jax.abstract_arrays
2023-09-19 15:40:18 -07:00
Yash Katariya
dcc465b4de
Finish jax and jaxlib 0.4.16 release
...
PiperOrigin-RevId: 566477931
2023-09-18 19:09:19 -07:00
Yash Katariya
a2720ee2c3
Deprecate jax.experimental.pjit.with_sharding_constraint
. Replacement is jax.lax.with_sharding_constraint which has been available since 1 year.
...
PiperOrigin-RevId: 565389746
2023-09-14 09:23:03 -07:00
Roy Frostig
1f8cc44f4e
deprecate PRNGKeyArray.unsafe_raw_array
in favor of jax.random.key_data
...
The latter function is also better in that its behavior is invariant to `jit`,
whereas the `unsafe_raw_array` method only works in eager mode.
PiperOrigin-RevId: 565195381
2023-09-13 16:33:56 -07:00
Jake VanderPlas
4e6c1b68c7
Deprecate random.KeyArray and random.PRNGKeyArray
2023-09-13 14:05:42 -07:00
Jake VanderPlas
eeb32a7d1f
Finish deprecation cycle for abstract_arrays.ShapedArray & abstract_arrays.raise_to_shaped
...
PiperOrigin-RevId: 565142019
2023-09-13 13:21:46 -07:00
Jake VanderPlas
22ff7bd19a
Finish the deprecation cycle for jnp.alltrue, jnp.sometrue, jnp.product, jnp.cumproduct
...
These have been deprecated in JAX following similar deprecations in numpy v1.25.0
PiperOrigin-RevId: 565122288
2023-09-13 12:07:36 -07:00
Peter Hawkins
408c657436
Add a release note about a fixed Windows crash.
2023-09-07 09:35:25 -04:00
Jake VanderPlas
ca39457ea9
JEX: move jax.linear_util to jax.extend.linear_util
2023-08-30 18:32:12 -07:00
Jake VanderPlas
4b89d03147
Deprecate the contents of jax.prng
2023-08-30 15:13:32 -07:00
Skye Wanderman-Milne
f71ba0a2e7
Update versions and changelog post 0.4.15 release
2023-08-30 16:20:25 -04:00
Peter Hawkins
9be96c1d69
Deprecate a number of exports from jax.interpreters.xla.
...
Custom HLO lowering rules for primitives should be updated to use MLIR StableHLO lowering rules via jax.interpreter.mlir.
PiperOrigin-RevId: 561215967
2023-08-29 20:47:59 -07:00
Peter Hawkins
93900245aa
Remove jax.interpreters.xla.register_collective_primitive.
...
We aren't consuming this data any more. It existed only to compare against the set of multiprocess-allowed collectives, but we removed that list also. So this registry is completely pointless.
PiperOrigin-RevId: 561150259
2023-08-29 15:10:05 -07:00
Yash Katariya
6072d5993e
Any devices passed to jax.sharding.Mesh are required to be hashable.
...
This is true for mock devices or user specific devices and jax.devices() too.
Fix the tests so that the mock devices are hashable.
PiperOrigin-RevId: 561103167
2023-08-29 12:20:54 -07:00
Peter Hawkins
46ac9e2170
Use the default CSR matmul algorithm.
...
Previously we requested CUSPARSE_SPMM_CSR_ALG3 in an attempt to get deterministic results from cusparse SpMM CSR matmuls. In the past, Cusparse silently ignored this algorithm choice and used a different algorithm in cases where ALG3 was not supported, but cusparse 12.2.1 removed the silent fallback behavior. Since we're not actually getting deterministic behavior anyway in all cases, use the default algorithm always.
PiperOrigin-RevId: 560867049
2023-08-28 17:49:01 -07:00
Peter Hawkins
975dae34a4
Deprecate jax.numpy.trapz.
...
Expose the current implementation of jax.numpy.trapz as jax.scipy.integrate.trapezoid instead.
Fixes https://github.com/google/jax/issues/17244
2023-08-25 09:04:13 -06:00
Jake VanderPlas
665b176c2c
remove deprecated jax.lax.prod function
...
PiperOrigin-RevId: 559787522
2023-08-24 10:13:59 -07:00
Peter Hawkins
7c871916f7
Deprecate jax.numpy.in1d.
...
Issue https://github.com/google/jax/issues/17244
2023-08-23 17:36:14 -06:00
Jake VanderPlas
19a57e1a01
Deprecate jax.numpy.row_stack
2023-08-22 13:12:49 -07:00
Peter Hawkins
4224a4d129
Deprecate jax.scipy.linalg.tril and jax.scipy.linalg.triu.
...
The corresponding functions are deprecated in scipy. Use the equivalent jax.numpy functions instead.
2023-08-18 16:14:42 -04:00
George Necula
ad15a38ec1
[host_callback] Remove old backwards compatibility flag jax_host_callback_ad_transforms.
...
This flag was added in https://github.com/google/jax/pull/8678 in December 2021
when we changed the behavior of host_callback to not have special handling for autodiff. Nobody is using that flag now.
This is part of a longer project to replace uses of host_callback with jax.pure_callback and jax.experimental.io_callback.
PiperOrigin-RevId: 557520668
2023-08-16 10:01:49 -07:00
George Necula
cf4e1d414b
[jax2tf] Bump the default JAX serialization version to 7.
...
This enables shape assertion checking, the support for which
landed in XlaCallModule on July 12th, 2023.
See the CHANGELOG for details.
PiperOrigin-RevId: 556222908
2023-08-11 22:49:31 -07:00
Jake VanderPlas
ad8e719b82
Add jnp.ufunc and jnp.frompyfunc
2023-08-10 14:58:18 -07:00
Peter Hawkins
0e80d959c8
Mark jnp.{NINF,NZERO,PZERO} as deprecated.
...
This follows the upstream NumPy deprecation of these names (https://github.com/numpy/numpy/pull/24357 ).
PiperOrigin-RevId: 555548986
2023-08-10 10:25:21 -07:00
Skye Wanderman-Milne
3e50fea29e
Remove option to use StreamExecutor Cloud TPU client in JAX
...
It's been over three months since the new PJRT C API client was
enabled by default
(https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-8-march-29-2023 ).
PiperOrigin-RevId: 554935166
2023-08-08 14:05:27 -07:00
Jake Vanderplas
d8f799391b
COPYBARA_INTEGRATE_REVIEW= https://github.com/google/jax/pull/17027 from jakevdp:dtypes-annotations a116a9c498a7b085f9b3fec93b37da12289f6e31
...
PiperOrigin-RevId: 554905739
2023-08-08 20:38:44 +00:00
Peter Hawkins
afd56c15d9
Move jax.jaxpr_util to jax._src.jaxpr_util, and split it into a separate build target.
...
Change jaxpr_util_test to be a py_test(), since there's no point testing it on every hardware configuration.
PiperOrigin-RevId: 554861284
2023-08-08 10:09:09 -07:00
Peter Hawkins
c879f65aa6
[JAX] Remove the non-coordination service distributed service implementation from JAX.
...
The coordination service has been the default for a long time, and has significant additional functionality. Remove the older code path to simplify the code.
PiperOrigin-RevId: 554608165
2023-08-07 15:17:25 -07:00
George Necula
8d80e2587b
[jax2tf] Turn on JAX native serialization by default.
...
See changes to the README.md for mechanisms to override the default.
PiperOrigin-RevId: 554390866
2023-08-07 01:03:55 -07:00
Patrick Kidger
d6dad3827d
Documented the shortening of tracebacks
2023-08-04 12:20:19 -07:00
Skye Wanderman-Milne
011fc88c03
Update versions and changelog for jax 0.4.14 release
2023-07-27 16:22:53 -07:00
jax authors
1ceddfc98a
Merge pull request #16710 from gnecula:poly_max0
...
PiperOrigin-RevId: 549515427
2023-07-19 21:40:17 -07:00
Jake VanderPlas
7205160095
Re-parameterize jax.random.gamma for better behavior at endpoints
2023-07-19 16:15:03 -07:00
jax authors
5ae3ac28cd
Add deprecation of jax.stages.Compiled.compiler_ir
to the change log
...
PiperOrigin-RevId: 549415191
2023-07-19 13:48:55 -07:00
George Necula
e643f98558
[shape_poly] Reimplement the shape constraint checking using shape assertions.
...
Most of the functionality is for the JAX native serialization case.
This relies on newly added functionality to xla_extension.refine_polymorphic_shapes
that handles custom calls @static_assertion.
As a beneficial side-effect now we get shape constraint checking for jax2tf
graph serialization when the resulting function is executed in graph mode.
2023-07-19 09:56:33 +03:00
Peter Hawkins
59509dc2b3
Remove the jax_array config option, which does nothing.
...
PiperOrigin-RevId: 548981491
2023-07-18 06:16:06 -07:00
Yash Katariya
f0ce0d8c6a
Delete in_axis_resources and out_axis_resources from pjit since it's been more than 3 months since their deprecation. The replace is to use in_shardings and out_shardings. You can still pass PartitionSpecs to {in|out}_shardings to pjit.
...
PiperOrigin-RevId: 548673905
2023-07-17 06:35:49 -07:00
George Necula
603eeb1901
Copybara import of the project:
...
--
06bf5fe7b2ac97156df541bab989dc5beb1aff0c by George Necula <gcnecula@gmail.com>:
[jax2tf] Added a flag and environment variable to control the serialization version.
This allows us to control the serialization version to be compatible with
the deployed version of tf.XlaCallModule. In particular, we can run
most tests with the maximum available version, while keeping the
default lower.
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16746 from gnecula:tf_version 06bf5fe7b2ac97156df541bab989dc5beb1aff0c
PiperOrigin-RevId: 548504243
2023-07-16 09:27:12 -07:00
Peter Hawkins
651f87733b
Remove jax_jit_pjit_api_merge.
...
PiperOrigin-RevId: 548236671
2023-07-14 15:25:00 -07:00
Yash Katariya
89c78bf53f
jax.jit now works correctly if both donate_argnums and donate_argnames are specified.
...
Update the docstring and changelog too to mention `donate_argnames`.
PiperOrigin-RevId: 548223395
2023-07-14 14:28:16 -07:00
Jake VanderPlas
21f6736005
Remove several deprecated APIs
2023-07-11 12:42:32 -07:00
Jake VanderPlas
b581ad1f33
Remove several deprecated jax.Array methods:
...
- arr.broadcast
- arr.broadcast_in_dim
- arr.split
These have been deprecated since JAX v0.4.5
PiperOrigin-RevId: 547228974
2023-07-11 10:34:27 -07:00
Jake VanderPlas
9962065deb
Require ml_dtypes>=0.2
2023-07-07 12:07:44 -07:00
Jake VanderPlas
d0e75ca117
Require index update optional arguments to be passed by keyword.
...
Passing these keywords by position has been deprecated and has raised a warning since JAX v0.4.7 (Released 27 March 2023)
PiperOrigin-RevId: 544620172
2023-06-30 04:30:34 -07:00
Roy Frostig
48903a382e
add corner-case cond
resolution fix to changelog
2023-06-28 10:09:10 -07:00
Jake VanderPlas
3f47ad367d
jax.interpreters.pxla: remove deprecated functions:
...
- jax.interpreters.pxla.device_put
- jax.interpreters.pxla.make_sharded_device_array
2023-06-27 21:49:55 -07:00
Yash Katariya
c632cace1e
Raise an error if a user passes None to host_local_array_to_global_array
or global_array_to_host_local_array
...
PiperOrigin-RevId: 543596009
2023-06-26 18:15:43 -07:00
Jake VanderPlas
ad35702934
Drop support for numpy 1.21
...
This is in accordance with NEP 29 and https://jax.readthedocs.io/en/latest/deprecation.html
2023-06-23 10:28:26 -07:00
Yash Katariya
19890086fa
[Rollback] Remove py3.8 support from jax as per https://jax.readthedocs.io/en/latest/deprecation.html
...
PiperOrigin-RevId: 542724110
2023-06-22 18:31:30 -07:00