* [jax2tf] Group error messages by dtype in pprint_limitations.
This makes the output of the categorizer more synthetic in cases
when the error is exactly the same for a given primitive on a set
of devices for different dtypes.
In Python 3.6 (maybe 3.7 too?), the lax.Precision enumeration
was not implicitly casted to int, which made the construction
of the xla_data_pb2.PrecisionConfig object fail in the conversion
of convolution.
* Add jax.linear_transpose
Co-authored-by: Matthew Johnson <mattjj@google.com>
* add failing test for complex numbers
* Add picky dtype check for linear_transpose
* Lint fix
* Allow truncating dtypes to match inputs in linear_transpose
* Fix typo in shape check error
* improve docstring
* Don't support integer inputs; better docstring
* fixup
* Fix doctest
Co-authored-by: Matthew Johnson <mattjj@google.com>
* improve an escaped tracer error message
Before this commit, encountering an escaped tracer in a specific way
would lead to a bad internal error. This change
1. raises an UnexpectedTracerError instead, and
2. includes in the error message the user source line which created the
tracer.
* deflake
* replace _live propety with _assert_live method
Thanks @jekbradbury !
* [jax2tf] Add tests for the conversion of conv_general_dilated.
This also adds the precision argument to the tfxla call which
was previously ignored.
* Separate orthogonal tests.
* Add options to compute L/R eigenvectors in geev.
The new arguments are by default set to True to ensure backwards
compatibility between jaxlib and jax.
Reformulate eig-related operations based on the new geev API.
* Addressed hawkinsp's comments from google/jax#3882.
Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>
* [jax2tf] Build a primitive harness for test_type_promotion.
We were previously generating the cases using `jtu.cases_from_list`,
which by default dropped 2 test cases (JAX_NUM_GENERATED_CASES=10,
number of generated cases = 12).
* [jax2tf] Fix the generated test cases for test_type_promotion.
* Simplify the internal interface for host_callback.id_tap
This is a breaking change for `id_tap` users (but not `id_print` users).
This makes it easier to use (and type check) ``tap_func``, because the
expected signature is now ``tap_func(arg, transforms)`` vs
``tap_func(arg, *, transforms, **kwargs)``.
Most of the test changes are just adding whitespace/indentation, but I've
also slightly changed the way transformations are printed.
* [jax2tf] Add converted primitives without tests to the generated
doc.
* Ignore some primitives in the output of untested primitives.
* Added control_flow_ops_test to template and updated primitives.
* Removed svd from the list of missing tests.
Was just included because I run the tests using
JAX_SKIP_SLOW_TESTS=1, which didn't run the SVD tests. Patched
the generated file manually.
correctness_stats code.
In principle, all the relevant documentation that was in the doc
has been moved to the new documentation & comments of categorize.
This fixes our RTD failures, which were caused by RTD installing an older version of pygments:
```
jupyterlab-pygments 0.1.1 requires pygments<3,>=2.4.1, but you'll have pygments 2.3.1 which is incompatible.
nbconvert 6.0.1 requires pygments>=2.4.1, but you'll have pygments 2.3.1 which is incompatible.
```
This is normally unnecessary, because the XLA translation usually
doesn't bind any of the primitives in the jaxpr, but this is not true in
case of scan! Its translation rule reevaluates the jaxpr as a function,
and if it contains collectives such as `axis_index` it can fail due to
axis being missing.
The permutation is more efficiently computed during the decomposition on TPU, and the only use case that would not require us to compute it would be for evaluating determinants.