The permutation is more efficiently computed during the decomposition on TPU; the only use case that would not require computing it is evaluating determinants.
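Assuming this refers to the LU decomposition, here is a hedged sketch of why determinants are the exception: they only need the pivot parity, not a materialized permutation (lu_factor and the sign formula below are standard, not code from this change):

    import jax.numpy as jnp
    from jax.scipy.linalg import lu_factor

    a = jnp.array([[4., 3.], [6., 3.]])
    lu, piv = lu_factor(a)
    # Each piv[i] != i records one row transposition, so the pivot
    # parity alone gives the sign of the permutation.
    sign = (-1.0) ** jnp.count_nonzero(piv != jnp.arange(piv.shape[0]))
    det = sign * jnp.prod(jnp.diag(lu))  # agrees with jnp.linalg.det(a)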
This fixes some errors that have been appearing in our CI from time to
time. All transformations are implemented as generators, but they were
not explicitly aborted when an exception was raised. Instead, they were
only closed when they were garbage collected, which could happen at an
unspecified later time, potentially corrupting global state that had
been modified after the exception was handled.
Note that this implementation doesn't propagate the original exception
into the argument transformations, and doesn't allow them to handle the
error either. Such an extension would be possible, but throwing an
exception into a generator mutates the exception object, clobbering
the nice traceback that we would usually carry. One can work around
those issues, but it feels really hacky and we don't need it right now
anyway, so I figured we'd be better off with the simple thing for the
time being.
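A minimal sketch of the pattern being fixed, with a hypothetical
transformation and global state standing in for the JAX internals:

    GLOBAL_STATE = {"mode": "normal"}  # hypothetical state the transformation mutates

    def transformation():
        saved = dict(GLOBAL_STATE)
        GLOBAL_STATE["mode"] = "transformed"
        try:
            yield  # the wrapped computation runs while we are suspended here
        finally:
            GLOBAL_STATE.clear()
            GLOBAL_STATE.update(saved)  # restore the state we changed

    def apply(fun, *args):
        gen = transformation()
        next(gen)  # enter the transformation
        try:
            out = fun(*args)
        except Exception:
            gen.close()  # abort explicitly so cleanup runs now, not at GC time
            raise
        next(gen, None)  # finish the generator normally on success
        return out

Without the explicit gen.close(), the finally block would only run
whenever the generator happened to be garbage collected.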
* [jax2tf] Add a template file for documentation generation.
The documentation now gives instructions about how to
regenerate it, as well as when it was last generated.
* Added a list of conversions that are not yet implemented.
* [jax2tf] Clean up code for XlaGather; experimental_compile no longer necessary
Now that XlaGather has been fixed in XLA, we no longer need the
experimental_compile workaround (which did not work anyway when
put in a SavedModel).
This fix requires a recent tf-nightly installation.
This new implementation is faster, and it works for polymorphic shapes without weird tricks. (It is faster even if we remove the weird tricks for polymorphism.)
Some of the vmap and gmap collective tests have been failing on master
and I can't seem to reproduce them locally. Hopefully, if
this happens again, this extra bit of information will be useful in
debugging the problem.
* [jax2tf] Expand coverage of primitives by categorize.
This commit adds handling logic for the limitations of:
- qr
- svd
- select_and_gather_add
- reduce_window/reduce_window_{min,max,sum}
- add
- mul
- scatter/scatter_{min,max,mul,add}
Also fixes a bug in a call to _infer_shape_jax, which wasn't
compatible with boolean operands and went undetected due to the
high-level handling of TF exceptions in higher-order primitives.
Since we implement threefry with signed integers when converting to TF,
we run into the type promotion 'uint32 - int32 = int64', which
then results in lax.shift_right_logical(uint32, int64), which fails.
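A minimal sketch of the promotion problem (illustrative only, not the
threefry code itself):

    import numpy as np
    from jax import lax

    x = np.uint32(8)
    shift = x - np.int32(1)  # mixed signedness promotes uint32 - int32 to int64
    print(shift.dtype)       # int64
    lax.shift_right_logical(x, shift)  # fails: lax requires matching dtypes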
* [jax2tf] Cleanup the correctness stats layout.
* Added Google license at the top of the file.
* Cleanup: fix docstring for 80 char boundary.
* Monkey patch/cleanup outside of the loop.
* Removed tensorflow dependency.
* Fixed the name of attributes of Limitation.
* [jax2tf] Implementation of random_gamma
The simplest implementation is to convert JAX's own impl_rule,
which rewrites gamma into other JAX primitives.
On TPU with use_vmap=True the performance is the same for JAX and TF, provided
we use tf.function(compile=True).
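A hypothetical usage sketch (not code from this change) of the
converted sampler, compiled in TF:

    import jax
    import tensorflow as tf
    from jax.experimental import jax2tf

    gamma_tf = tf.function(jax2tf.convert(jax.random.gamma),
                           experimental_compile=True)
    sample = gamma_tf(jax.random.PRNGKey(0), 2.0)  # a TF tensor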
This mainly follows https://github.com/google/jax/pull/4089 by adding:
- support for disable_jit from C++
- support for jax._cpp_jit on methods
- support for applying @jax.jit to top-level functions, by delaying the retrieval of the device and backend
- concurrency support
I am not aware of any missing features (but I suspect there are still some differences between xla_computation and _xla_callable).
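A minimal sketch of the user-visible pieces (standard public API,
shown only to illustrate the features listed above):

    import jax
    import jax.numpy as jnp

    @jax.jit                  # applied to a top-level function: the device and
    def f(x):                 # backend are only looked up at the first call
        return x * 2.0

    print(f(jnp.ones(3)))

    with jax.disable_jit():   # honored by the C++ jit path as well
        print(f(jnp.ones(3)))  # runs op by op, which is easier to debug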
See:
- https://i.ibb.co/ZMvZ4nK/benchmark.png for the benchmarking comparison (see
cr/328899906 + benchmarks for how numbers were generated)
- The results of the Jax tests when enabling this:
http://sponge2/4a67d132-209f-45c5-ab7b-83716d329ec2 (110 failures, 92 passes, but many failures share a common cause).
A series of PRs renaming the frame entries has been submitted, one of them introducing a bug when using omnistaging. This PR fixes that bug and removes a leftover print statement (presumably added for debugging purposes).
As I was writing the demo I realized that it makes more sense for
with_gradient to be set to True by default.
I have also fixed a bug with tie_in in omnistaging.
* applied simple find+sed for 'master' -> 'main'
* Rename master->main in JAX API and internals (#4178)
* Started with #4174
* Renamed Trace.master to Trace.main
* Renamed core.new_master and core.new_base_master
Co-authored-by: George Necula <gcnecula@gmail.com>
* Implement a proper shape checking rule for gather.
The implementation is based on the corresponding shape inference
code in `tensorflow/compiler/xla/service/shape_inference.cc`. The
tests added in `tests/lax_test.py` are similarly mirroring the
corresponding tests in tensorflow, with slight adaptations for the
particular setting of JAX. Fixes google/jax#2826, and in principle
fixes google/jax#4154 and google/jax#3905.
* Extracted common functions for gather/scatter shape checking rules.
* Addition of one more conclusive polynomial comparison case.
When the difference between two polynomials is a constant, it is
possible to compare them conclusively. This commit adds such a case
to masking.Poly.__ge__ (see the sketch below).
* Added a few relevant tests in tests.masking_test.test_Poly_compare.
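A simplified sketch of the new case (not the actual masking.Poly code),
representing a polynomial as a dict from monomials to coefficients,
with the constant term keyed by ():

    def poly_ge(p, q):
        keys = set(p) | set(q)
        diff = {k: p.get(k, 0) - q.get(k, 0) for k in keys}
        if all(c == 0 for k, c in diff.items() if k != ()):
            return diff.get((), 0) >= 0  # the difference is a constant: conclusive
        raise ValueError("inconclusive polynomial comparison")

    # n + 3 >= n + 1 is decidable even though n is unknown:
    assert poly_ge({('n',): 1, (): 3}, {('n',): 1, (): 1})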
This allows executing collectives over the gmapped axes. This requires
some extra manipulation of the gmapped jaxpr, since gmap exposes a
single logical axis name, but evaluates the program using multiple
"physical" axes.
This also fixes some bugs around handling `multiple_returns` in
vmap collective implementation.
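For illustration, here is what a collective over a mapped axis looks
like with vmap's axis_name (gmap follows the same model; this is a
generic example, not a test from this change):

    import jax
    import jax.numpy as jnp
    from jax import lax

    def normalize(x):
        return x / lax.psum(x, axis_name="i")  # collective over the mapped axis

    print(jax.vmap(normalize, axis_name="i")(jnp.arange(1.0, 5.0)))
    # [0.1 0.2 0.3 0.4]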
* [jax2tf] Added conversion for scatter*_p primitives.
Limitations:
- the conversion works only as well as the conversion of the underlying reduction function (e.g., lax.scatter_max is not properly converted for the int8 dtype, because tf.math.maximum is not defined for int8 tensors);
- the conversion cannot take advantage of the unique_indices parameter. This does not affect correctness, but it may affect performance on certain platforms (as stated in the documentation of lax.scatter).
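A hypothetical sketch of round-tripping a scatter-add through the
conversion (the .at[].add spelling is today's public API and lowers to
lax.scatter_add; the indices here are made up for illustration):

    import jax.numpy as jnp
    from jax.experimental import jax2tf

    def add_at(operand, updates):
        return operand.at[jnp.array([0, 2])].add(updates)

    add_at_tf = jax2tf.convert(add_at)
    print(add_at_tf(jnp.zeros(4), jnp.ones(2)))  # [1. 0. 1. 0.]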
* Put tf.function experimental compile wrapper back on scatter.
* Removed unique_indices=True test cases
* Remove non-deterministic test cases from the scatter harness.
This commit also documents the reasons for ignoring these test
cases and potential pitfalls, in case someone needs to perform
these tests at a later time.
* Add a boolean to _check_shapelike to accept or reject shapes
corresponding to arrays of 0 elements. (Fixes google/jax#3972.)
* Added test for failures referenced in issue 3972.
* [jax2tf] Add testing for add/mul/min/max conversion.
Only certain types are supported for each of the operations above.
This commit adds previously missing tests to make this explicit.
This code was failing with "KeyError: psum" for the tests
"//third_party/py/flax/...". I suspect that the error is due to the
ordering of the omnistaging enablers, changed in #4152.
I am not sure about this fix, but it seemed to be enough for all the
presubmit tests to pass and to allow the copybara import.