* Combine the concepts of "platform" and "backend". The main upshot of this is that the tpu_driver backend requires users to write `jit(..., backend="tpu_driver")` if mixing CPU and TPU execution, however I doubt users are writing that because it didn't work to mix CPU and tpu_driver before.
* Initialize all platforms at startup, rather than lazily initializing platforms on demand. This makes it easy to do things like "list the available platforms".
* Don't use two levels of caching. Cache backends only in xla_bridge.py, not xla_client.py.
PiperOrigin-RevId: 376883261
The main problem was that jnp.einsum uses opt_einsum.contract_path
to parse the specification string and compute the order or the
contractions. This function wants to compute the sizes of operands
and intermediate results, and will fail if some dimensions are
polymorphic.
The (partial) solution here is to replace the operands with
jax.ShapeDtypeStruct with a fixed size for all dimension variables,
then call opt_einsum.contract_path and use that result if there
is only one contraction. We abort if there are multiple contractions.
This behavior is clearly sound. If there were multiple contractions,
perhaps their order would be different with different dimension sizes.
--
8226dfc8a4974b4c8031ee267fa5327e778140ee by Nicholas Junge <nicholas.junge@web.de>:
Handle negative values for list-like sections in jnp.split
PiperOrigin-RevId: 376302305
--
746a232632652233f649b15d94f3ed2fd0ccc1fb by George Necula <gcnecula@gmail.com>:
[jax2tf] Updates known limitations.
This PR fixes several issues:
* It updates the documentation of the known limitations
* Increases the numerical tolerance for conv_general_dilated on GPU, to
address test flakiness.
* Adds a workaround for a TF bug that results in a crash when
trying to extract the optimized HLO.
--
4302101aed30a2c7625a2dd5acbe1ca17f9540e4 by George Necula <gcnecula@gmail.com>:
Added limitation for dot_general on GPU
--
207f66a970b7f596e1b265c7aa91fa56e27e7d51 by George Necula <gcnecula@gmail.com>:
Added limitation for dot_general on GPU
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/6837 from gnecula:tf_adjust_lim 207f66a970b7f596e1b265c7aa91fa56e27e7d51
PiperOrigin-RevId: 375910042