This fixes our RTD failures, which were caused by RTD installing an older version of pygments:
```
jupyterlab-pygments 0.1.1 requires pygments<3,>=2.4.1, but you'll have pygments 2.3.1 which is incompatible.
nbconvert 6.0.1 requires pygments>=2.4.1, but you'll have pygments 2.3.1 which is incompatible.
```
This is normally unnecessary, because the XLA translation usually
doesn't bind any of the primitives in the jaxpr, but this is not true in
the case of scan! Its translation rule re-evaluates the jaxpr as a
function, and if that function contains collectives such as `axis_index`
it can fail because the named axis is no longer in scope.
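For illustration, a minimal (hypothetical) reproduction of the pattern: a `scan` whose body calls the collective `lax.axis_index`, evaluated under a mapped axis. When scan's translation rule re-evaluates the jaxpr as a function, the named axis must still be in scope for this to work:

```python
import jax
import jax.numpy as jnp
from jax import lax

def body(carry, x):
    # A collective inside the scan body: only valid while the
    # named axis "batch" is in scope.
    i = lax.axis_index("batch")
    return carry + x * i, None

def f(xs):
    total, _ = lax.scan(body, jnp.zeros(()), xs)
    return total

# Map over a named axis so that axis_index("batch") is well defined.
out = jax.vmap(f, axis_name="batch")(jnp.ones((4, 3)))
```

Each mapped instance sums `x * i` over three ones, so instance `i` yields `3 * i`.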
On TPU, the permutation is computed more efficiently during the decomposition itself, and the only use case that would not require computing it is evaluating determinants.
This fixes some errors that have been appearing in our CI from time to
time. All transformations are implemented as generators, but they were
not explicitly aborted when an exception was raised. Instead, they were
only closed when they were garbage collected, which could happen at an
unspecified later time, potentially corrupting global state that had
been modified after the exception was handled.
Note that this implementation doesn't propagate the original exception
into the argument transformations, and doesn't allow them to handle the
error either. Such an extension would be possible, but throwing an
exception into a generator mutates the exception object, clobbering
the nice traceback that we would usually carry. One can work around
those issues, but it feels really hacky and we don't need it right now
anyway, so I figured we're better off with the simple thing for the
time being.
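The failure mode can be sketched with plain Python generators (names are illustrative, not JAX's actual internals): cleanup in a generator's `finally` block only runs when the generator is closed, either explicitly or by the garbage collector at some unpredictable later time, so an explicit `close()` on the error path restores state immediately:

```python
# Illustrative sketch (not JAX's actual code): a transformation written
# as a generator that mutates global state and restores it in `finally`.
global_state = {"trace_level": 0}

def transformation():
    global_state["trace_level"] += 1
    try:
        yield  # hand control back to the caller; the wrapped body runs here
    finally:
        # Only runs when the generator is closed (explicitly or by the GC).
        global_state["trace_level"] -= 1

def run_with_explicit_close():
    gen = transformation()
    next(gen)  # enter the transformation
    try:
        raise ValueError("user code failed")
    except ValueError:
        pass
    finally:
        gen.close()  # abort the generator now, restoring state immediately

run_with_explicit_close()
```

Without the `gen.close()`, the `finally` block would only run once the GC finalized the generator, leaving `trace_level` corrupted in the meantime.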
* [jax2tf] Add a template file for documentation generation.
The documentation now gives instructions about how to
regenerate it, as well as when it was last generated.
* Added a list of conversions that are not yet implemented.
* [jax2tf] Clean up code for XlaGather, experimental_compile not necessary
Now that XlaGather has been fixed in XLA, we no longer need the
experimental_compile workaround (which was not working anyway when
put in a SavedModel).
This fix requires a recent tf-nightly installation.
1. `wheel.pep425tags` has been removed as of
https://github.com/pypa/setuptools/pull/1829. Use the new
`packaging.tags` instead.
2. Add `--allow-downgrades` to the CUDA install command. I'm not sure this
is always necessary, but I ran into a failure without it, most likely
due to a cached Docker image.
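For point 1 above, the replacement lookup can be sketched as follows (assuming the `packaging` library is available, as it usually is alongside pip):

```python
# packaging.tags replaces the removed wheel.pep425tags module for
# querying the compatibility tags of the running interpreter.
from packaging import tags

# sys_tags() yields tags from most to least specific,
# e.g. cp311-cp311-manylinux_2_17_x86_64.
supported = list(tags.sys_tags())
most_specific = supported[0]
print(most_specific.interpreter, most_specific.abi, most_specific.platform)
```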
This new implementation is faster and works for polymorphic shapes without weird tricks. (It is faster even if we remove the weird tricks for polymorphism.)
Some of the vmap and gmap collective tests have been failing on master,
and I can't reproduce the failures locally. Hopefully, if this happens
again, this extra bit of information will be useful in debugging the
problem.
* [jax2tf] Expand coverage of primitives by categorize.
This commit adds handling logic for the limitations of:
- qr
- svd
- select_and_gather_add
- reduce_window/reduce_window_{min,max,sum}
- add
- mul
- scatter/scatter_{min,max,mul,add}
Also fixes a bug in a call to _infer_shape_jax, which wasn't
compatible with boolean operands and went undetected due to the
high-level handling of TF exceptions in higher-order primitives.
Since we implement threefry with signed integers when converting to TF,
we run into the type promotion `uint32 - int32 = int64`, which
then results in `lax.shift_right_logical(uint32, int64)`, which fails.
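The offending promotion can be checked with NumPy (JAX promotes this pair the same way when 64-bit types are enabled): `uint32` and `int32` have no common 32-bit type that holds both ranges, so the pair promotes to `int64`:

```python
import numpy as np

# No 32-bit type can represent both the uint32 and int32 ranges,
# so NumPy promotes the mixed pair to int64.
promoted = np.result_type(np.uint32, np.int32)
print(promoted)  # int64
```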
* [jax2tf] Cleanup the correctness stats layout.
* Added Google license at the top of the file.
* Cleanup: fix docstring for 80 char boundary.
* Monkey patch/cleanup outside of the loop.
* Removed tensorflow dependency.
* Fixed the name of attributes of Limitation.
* [jax2tf] implementation of random_gamma
The simplest implementation is to convert JAX's own impl_rule,
which rewrites gamma into other JAX primitives.
On TPU with use_vmap=True the performance is the same for JAX and TF, provided
we use tf.function(compile=True).
This mainly follows https://github.com/google/jax/pull/4089 by adding:
- support for disable_jit from C++
- support for jax._cpp_jit on methods
- support for applying @jax.jit to top-level functions, by delaying the retrieval of the device and backend
- concurrency support
I am not aware of any missing features (but I suspect there are still some differences due to the differences between xla_computation and _xla_callable).
See:
- https://i.ibb.co/ZMvZ4nK/benchmark.png for the benchmarking comparison (see
cr/328899906 + benchmarks for how numbers were generated)
- The results of the Jax tests when enabling this:
http://sponge2/4a67d132-209f-45c5-ab7b-83716d329ec2 (110 fails, 92 passes, but many failures share a common cause).
A series of PRs renaming the frame entries has been submitted, and one of them introduced a bug when using omnistaging. This PR fixes that and removes a stray print statement (presumably added for debugging purposes).