* remove the dead code KeyTangentTy
* replace TyRules.make_tangent with TyRules.zero
* removed ad.instantiate_zeros_aval, which was redundant with ad.instantiate_zeros ever since (1) we removed units and (2) we made Zero carry an aval on it
* fix a bug in backward_pass where we instantiated a Zero at the primal type rather than the corresponding tangent type
* fix _f_bwd in test_keyarray_custom_vjp, which had the wrong type (need to return cotangents for all inputs, we were returning a (float_tangent, key_tangent) pair instead of a (float_tangent, (float_tangent, key_tangent)) nested tuple, see #19009 for a check which catches this and hence includes the same test change
We probably also need a TyRules.add for any extended dtypes that can occur as tangent dtypes, but we currently don't have any tests that exercise that (because all extended dtype tangent types are currently float0). I have some follow-up work to add such a case though!
I think a31129a aka cl/587963496 accidentally made hypothesis a test dependency in tests/all_gather_test.py, rather than following our existing convention as in tests/state_test.py of making it optional. I think it was an accident because there's no discussion of adding hypothesis as a test dependence on the review for that PR/CL.
This PR changes tests/all_gather_test.py to follow the convention for making hypothesis optional.
The guarantee provided by this API is that the sharding of `x` and `y` will be the same! What the sharding will be is decided by GSPMD.
The flow of sharding is bidirectional i.e. SPMD will choose what the sharding should be of `x` and `y` depending on it's propagation algorithm. It might end up being that the sharding chosen is not of `x` and `y` but something better. At the end of propagation `x` and `y` will be sharded alike.
The API can be made variadic in the future i.e. `*args = shard_alike(*args)` depending on use cases.
Fixes: https://github.com/google/jax/issues/15600
PiperOrigin-RevId: 592375936
This was deprecated prior to the JAX 0.4.16 release, so we have now met the 3 month deprecation period specified by the [API Compatiblity Policy](https://jax.readthedocs.io/en/latest/api_compatibility.html).
PiperOrigin-RevId: 592266215
Previously the environment variable JAX_DUMP_IR_TO controlled
whether and where to dump the MLIR module prior to compilation. Now we move the code for that support from
compiler.py to mlir.py, so that it can be used in other
parts of the code. We also add support for logging to Sponge.
Using this support we now log the module on errors from
refine_polymorphic_shapes.
PiperOrigin-RevId: 592099633
These tests are independent of TensorFlow, yet by being in the jax2tf package they end up pulling in TensorFlow as a dependency.
This is part of a larger cl/562671314 that ran into OSS build problems.
This is step 2: moves the other test data Python files.
PiperOrigin-RevId: 591934999
All of these were deprecated prior to the JAX 0.4.16 release, on Sept 18 2023.
As of Monday Dec 18, we have met the 3 month deprecation period specified by the [API Compatiblity Policy](https://jax.readthedocs.io/en/latest/api_compatibility.html).
PiperOrigin-RevId: 591933493
These tests are independent of TensorFlow, yet by being in the jax2tf package they end up pulling in TensorFlow as a dependency.
This is part of a larger cl/562671314 that ran into OSS build problems.
I am attempting this smaller change first, and afterwards I will move more of the test data files, and then the actual test.
PiperOrigin-RevId: 591927484
There were two problems:
* the float0 dtype was not part of the schema,
* there was a bug invoking jax.vjp on a reloaded
function, because of a mismatch between the type
of symbolic zeros.
We changed the schema to add `f0`, but we add that
enum with a value larger than existing values, to
preserve backwards compatibility.