The latter function is also better in that its behavior is invariant to `jit`,
whereas the `unsafe_raw_array` method only works in eager mode.
PiperOrigin-RevId: 565195381
We aren't consuming this data any more. It existed only to compare against the set of multiprocess-allowed collectives, but we removed that list also. So this registry is completely pointless.
PiperOrigin-RevId: 561150259
This is true for mock devices or user specific devices and jax.devices() too.
Fix the tests so that the mock devices are hashable.
PiperOrigin-RevId: 561103167
Previously we requested CUSPARSE_SPMM_CSR_ALG3 in an attempt to get deterministic results from cusparse SpMM CSR matmuls. In the past, Cusparse silently ignored this algorithm choice and used a different algorithm in cases where ALG3 was not supported, but cusparse 12.2.1 removed the silent fallback behavior. Since we're not actually getting deterministic behavior anyway in all cases, use the default algorithm always.
PiperOrigin-RevId: 560867049
This flag was added in https://github.com/google/jax/pull/8678 in December 2021
when we changed the behavior of host_callback to not have special handling for autodiff. Nobody is using that flag now.
This is part of a longer project to replace uses of host_callback with jax.pure_callback and jax.experimental.io_callback.
PiperOrigin-RevId: 557520668
This enables shape assertion checking, the support for which
landed in XlaCallModule on July 12th, 2023.
See the CHANGELOG for details.
PiperOrigin-RevId: 556222908
The coordination service has been the default for a long time, and has significant additional functionality. Remove the older code path to simplify the code.
PiperOrigin-RevId: 554608165
Most of the functionality is for the JAX native serialization case.
This relies on newly added functionality to xla_extension.refine_polymorphic_shapes
that handles custom calls @static_assertion.
As a beneficial side-effect now we get shape constraint checking for jax2tf
graph serialization when the resulting function is executed in graph mode.
--
06bf5fe7b2ac97156df541bab989dc5beb1aff0c by George Necula <gcnecula@gmail.com>:
[jax2tf] Added a flag and environment variable to control the serialization version.
This allows us to control the serialization version to be compatible with
the deployed version of tf.XlaCallModule. In particular, we can run
most tests with the maximum available version, while keeping the
default lower.
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16746 from gnecula:tf_version 06bf5fe7b2ac97156df541bab989dc5beb1aff0c
PiperOrigin-RevId: 548504243