These tests are independent of TensorFlow, yet by being in the jax2tf package they end up pulling in TensorFlow as a dependency.
This is part of a larger cl/562671314 that ran into OSS build problems.
This is step 2: moves the other test data Python files.
PiperOrigin-RevId: 591934999
All of these were deprecated prior to the JAX 0.4.16 release, on Sept 18 2023.
As of Monday Dec 18, we have met the 3 month deprecation period specified by the [API Compatiblity Policy](https://jax.readthedocs.io/en/latest/api_compatibility.html).
PiperOrigin-RevId: 591933493
These tests are independent of TensorFlow, yet by being in the jax2tf package they end up pulling in TensorFlow as a dependency.
This is part of a larger cl/562671314 that ran into OSS build problems.
I am attempting this smaller change first, and afterwards I will move more of the test data files, and then the actual test.
PiperOrigin-RevId: 591927484
There were two problems:
* the float0 dtype was not part of the schema,
* there was a bug invoking jax.vjp on a reloaded
function, because of a mismatch between the type
of symbolic zeros.
We changed the schema to add `f0`, but we add that
enum with a value larger than existing values, to
preserve backwards compatibility.
batch_device_put in libtpu will go through memory space path, and requires CopyToMemorySpace to be implemented if the backend uses memory space.
PiperOrigin-RevId: 591373300
Previously we used `from jax.experimental.export import export` and
`export.export(fun)`. Now we want to add the public API directly
to `jax.experimental.export`, for the following desired usage:
```
from jax.experimental import export
exp: export.Exported = export.export(fun)
ser: bytearray = export.serialize(exp)
exp1 = export.deserialized(ser)
export.call(exp1)
```
This change also includes a workaround to allow users to still
do `from jax.experimental.export import export`, for a while.
We currently support only the serialization of effects with
nullary constructors. We must also ensure that upon deserialization
we produce an event that tests equal to the original one.
Here we add explicit error checks and tests.
We also make the CallTfEffect to have this property.
I used the same implementation technique in shard_map.py, e.g. in ShardMapTrace.process_custom_jvp_call, and it's sound, whereas I can't remember why we implementd the eager pmap stuff the way we did.
This fixes an internal test, but unfortunately I wasn't able to figure out a simple repro :/