702 Commits

Author SHA1 Message Date
jax authors
c8f5b2bb13 Merge pull request #24481 from jakevdp:key-array-error
PiperOrigin-RevId: 694626415
2024-11-08 13:47:05 -08:00
Jake VanderPlas
83383fc717 Error on numpy array conversion of PRNG key array 2024-11-07 10:08:49 -08:00
Jake VanderPlas
1af3b01c1c register_dataclass: allow marking static fields via field(static=True) 2024-11-06 11:18:11 -08:00
Jake VanderPlas
095bb0e742 Make Tracers non-hashable 2024-11-05 09:08:33 -08:00
Jake VanderPlas
e9acaa8484 Remove the initial argument to jax.nn.softmax and jax.nn.log_softmax.
This argument was deprecated in JAX v0.4.27 and has no effect in JAX v0.4.27 and later.

PiperOrigin-RevId: 693023366
2024-11-04 10:55:21 -08:00
George Necula
292a00b35a [export] Cleanup in the export module.
With jax.experimental.export gone we can now do some cleanup in the export module.

In particular we remove the `export.args_spec` API, and the `lowering_platforms` arg for `export.export`. These were deprecated in June 2024.

PiperOrigin-RevId: 692398132
2024-11-01 22:56:44 -07:00
Matthew Johnson
0f3ba4250d support exec_time_optimization_effort and memory_fitting_effort xla compilation
options

PiperOrigin-RevId: 692322944
2024-11-01 16:25:50 -07:00
Jake VanderPlas
e61a20b45a Remove deprecated jax.experimental.export module.
These tools are now available at jax.export.
2024-10-30 05:27:29 -07:00
Jake VanderPlas
d4c46825d6 Finalize deprecation of xb, xc, & xe symbols in jax.interpreters.xla
PiperOrigin-RevId: 689792265
2024-10-25 08:12:44 -07:00
George Necula
c62b19883f Fix copy and paste error in CHANGELOG. 2024-10-25 16:11:35 +03:00
George Necula
9088adda68 [jax2tf] Disable jax2tf with non-native serialization.
jax2tf with native_serialization=False or with enable_xla=False have been deprecated since July 2024.

This change turns an attempt to use `native_serialization=False` or `enable_xla=False` into an error.

PiperOrigin-RevId: 689708392
2024-10-25 02:30:54 -07:00
Peter Hawkins
2aeda17829 Merge branch 'release/0.4.35' 2024-10-23 08:50:31 -04:00
Peter Hawkins
e4f3f8f064 Use libtpu releases rather than libtpu-nightly for jax[tpu].
PiperOrigin-RevId: 688632409
2024-10-22 11:47:07 -07:00
Peter Hawkins
e9c7ff0b7d Deprecate a number of APIs in jax.lib.xla_client.
(Technically these aren't public, so they don't need a deprecation period, but this is the polite thing to do.)

PiperOrigin-RevId: 684906277
2024-10-11 11:42:40 -07:00
Dan Foreman-Mackey
f55141ef0e Fix listing of vectorized deprecation in changelog.
As noted in https://github.com/jax-ml/jax/pull/23881, that change didn't
actually make it in in time for the v0.4.34 release so I've moved it to
the v0.4.35 section.
2024-10-10 15:40:01 -04:00
Peter Hawkins
aa3254d723 Deprecate jax.lib.xla_client.PaddingType.
This type is unused by JAX, so there is no replacement.

(JAX does have an internal PaddingType enum in lax, but it is not present in any APIs, as best I can tell.)

PiperOrigin-RevId: 684451556
2024-10-10 08:22:20 -07:00
Peter Hawkins
94abaf430e Add lax.FftType.
We had never provided a public name for the enum of FFT types; instead it was only known by a semi-private name (jax.lib.xla_client.FftType). Add a public name (jax.lax.FftType) and deprecate the private one.

We define a new FftType IntEnum rather than trying to expose the one in xla_client. The xla_client definition was useful when building classic HLO, but we no longer do that so there's no reason we need to couple our type to XLA's type.

PiperOrigin-RevId: 684447186
2024-10-10 08:07:35 -07:00
Yuxuan Jiang
757a77ede0
Fix wrong date in changelog 2024-10-06 23:16:30 +08:00
George Necula
db89c245ac [host_callback] Remove most of the jax.experimental.host_callback module
These APIs have been deprecated since March 2024 and they are subsumed by the new JAX external callbacks.
See https://github.com/google/jax/issues/20385 for a discussion.

PiperOrigin-RevId: 682830525
2024-10-06 01:10:34 -07:00
Jake VanderPlas
45f0e9ad68 Simplify definition of jnp.isscalar
The new semantics are to return True for any array-like object with zero dimensions.
Previously we only returned True for zero-dimensional array-like objects with a weak type. This ends up being more confusing/suprising than it needs to be, and the weak type dependence is rarely useful in practice.

PiperOrigin-RevId: 682656411
2024-10-05 07:12:20 -07:00
Peter Hawkins
b0b7a60e63 Merge branch 'release/0.4.34' 2024-10-04 10:56:18 -04:00
Dan Foreman-Mackey
1d27d420ac Deprecate the vectorized argument to pure_callback and ffi_call. 2024-10-02 11:33:51 -04:00
Jake VanderPlas
49ad220e57 Finalize deprecation of XLACompatibleSharding
PiperOrigin-RevId: 681156145
2024-10-01 14:02:34 -07:00
George Necula
2228115cf4 [host_callback] Flip the JAX_HOST_CALLBACK_LEGACY flag to False
`jax.experimental.host_callback` has been deprecated since March 2024
 (JAX version 0.4.26). Now we set the default value of the `--jax_host_callback_legacy` configuration value to `True`, which means that if your code uses `jax.experimental.host_callback` APIs, those API calls will be implemented in terms of the new `jax.experimental.io_callback` API.

If this breaks your code, for a very limited time, you can set the `--jax_host_callback_legacy` to `True`. Soon we will remove that configuration option, so you should instead transition to using the new JAX callback APIs.

See https://github.com/google/jax/issues/20385 for a discussion.

PiperOrigin-RevId: 681004255
2024-10-01 07:07:29 -07:00
carlosgmartin
65a58d622c Edit implementation of jax.numpy.ldexp to get correct gradient. 2024-09-30 18:27:39 -04:00
Peter Hawkins
0e082f978b Deprecate jax.lib.xla_client.Device.
jax.Device is a longstanding public name for this class.

PiperOrigin-RevId: 679197718
2024-09-26 10:17:04 -07:00
Peter Hawkins
7b53c2f39d Add jax.errors.JaxRuntimeError as a public alias for the XlaRuntimeError class.
Deprecate jax.lib.xla_client.XlaRuntimeError, which is not a public API.

PiperOrigin-RevId: 679163106
2024-09-26 08:39:30 -07:00
Jake VanderPlas
e05c37c667 Finalize deprecation of pretty-printing utils in jax.core.pp_*
PiperOrigin-RevId: 678775782
2024-09-25 11:20:35 -07:00
Peter Hawkins
111f13e279 Reverts dffac29e63de6a51047fe77cf9d553ab762ef19b
PiperOrigin-RevId: 678748794
2024-09-25 10:14:45 -07:00
Peter Hawkins
562e9e8dff Fix an incorrect output for jnp.cumsum.
If dtype=bool but a non-bool input is passed, we should test for
non-equality with zero rather than performing a cast to integer.
2024-09-24 14:46:44 +00:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Peter Hawkins
6a3736a1d7 Add a note to the changelog about the new CPU thunks backend, enabled in 0.4.32. 2024-09-19 15:38:52 -04:00
Peter Hawkins
bef36c431d Add Python 3.13 wheels to changelog. 2024-09-18 18:57:03 +00:00
rajasekharporeddy
2714469397 Deprecate passing NdArrays with ndim != 1 and non-arraylike inputs to jnp.trim_zeros 2024-09-18 17:06:28 +05:30
Peter Hawkins
ae0e403c60 Merge release/0.4.33 into main and update version numbers. 2024-09-16 18:46:24 +00:00
Peter Hawkins
80e1c94de6 Prepare for v0.4.33 release.
This release is branched off the v0.4.32 release, with two changes:
a) a fixed libtpu pin, and
b) a patch to revert an F64 tanh issue on CPU.
2024-09-16 13:30:35 +00:00
Sergei Lebedev
b886bd7300 Removed the named_shape argument from jex.core.ShapedArray and jax.ShapeDtypeStruct
It is unused and was only kept around to avoid breaking internal users.

PiperOrigin-RevId: 674310795
2024-09-13 08:38:15 -07:00
Peter Hawkins
dffac29e63 Reverts 255c30303d32e7473262b2e35348175c87e4348f
PiperOrigin-RevId: 674083626
2024-09-12 18:14:25 -07:00
Peter Hawkins
255c30303d Fix a bug where treedef.flatten_up_to(...) was overly permissive for None treedefs.
For example, tree_map(..., None, [2, 3]) did not raise an error, but None is a container and only leaves can be considered tree prefixes in this case.

PiperOrigin-RevId: 674019460
2024-09-12 14:49:18 -07:00
Yash Katariya
de9b98e0a8 Delete jax.xla_computation since it's been 3 months since it was deprecated.
PiperOrigin-RevId: 673938336
2024-09-12 11:47:38 -07:00
Parker Schuh
bf2237a102 Flip jax_pmap_no_rank_reduction by default to True.
This changes:
* The performance of array[0] (use array[0:1] instead).
* The shape of jax_array.addressable_shards or jax_array.addressable_data(0) of arrays that come from pmap.

PiperOrigin-RevId: 673564995
2024-09-11 15:41:47 -07:00
Peter Hawkins
3e81ae530d Update version numbers after v0.4.32 release. 2024-09-11 16:18:56 -04:00
Dan Foreman-Mackey
bcbc0962bb Add the FFI functions and tutorial to the changelog.
Although we soft launched the FFI with v0.4.31, it would be nice to
include an update in the changelog to help with visibility.
2024-09-06 12:30:28 -04:00
Peter Hawkins
9c86fdec02 Make optimization_barrier a public lax API. 2024-09-06 00:18:57 +00:00
Sergei Lebedev
1289640f09 Deprecated calling `jax.dlpack.from_dlpack` with a DLPack tensor
PiperOrigin-RevId: 670723176
2024-09-03 15:16:02 -07:00
Jake VanderPlas
f2ffe7f8f2 Deprecate jax.numpy.round_
NumPy removed np.round in version 2.0; jax.numpy.round is drop-in
replacement.
2024-09-03 06:52:07 -07:00
Jake VanderPlas
a3d6cf007e First pass at ufunc interfaces for several jax.numpy functions 2024-08-30 11:53:02 -07:00
Sergei Lebedev
02bb884357 `jax.tree_util.register_dataclass now validates data_fields and meta_fields`
A well-behaved registration call must list all ``init=True`` fields in either ``data_fields`` or ``meta_fields``. Otherwise, ``flatten . unflatten`` could potentially *not* be an identity

PiperOrigin-RevId: 669244669
2024-08-30 02:01:50 -07:00
rajasekharporeddy
ced012f5ed Update jnp.fabs to emulate the behavior of np.fabs for complex inputs 2024-08-28 20:16:09 +05:30
Bryan Massoth
b38f985b01 Add a callout that LibTPU now supports profiling of SparseCore for TPUv5p chips which will be viewable in Tensorboard Profiler's TraceViewer tool.
PiperOrigin-RevId: 667708094
2024-08-26 14:04:43 -07:00