5212 Commits

Author SHA1 Message Date
George Necula
d74e81cc8b
[jax2tf] Disable complex convolution test on GPU: crash in TF (#4319) 2020-09-17 16:31:17 +03:00
George Necula
2ff593747e
[jax2tf] Change precision of test_conv to fix TPU tests (#4317) 2020-09-17 11:45:11 +03:00
Benjamin Chetioui
b1d0f87648
[jax2tf] Group error messages by dtype in pprint_limitations. (#4307)
* [jax2tf] Group error messages by dtype in pprint_limitations.

This makes the output of the categorizer more synthetic in cases
when the error is exactly the same for a given primitive on a set
of devices for different dtypes.
2020-09-17 11:03:31 +03:00
Matthew Johnson
b81c246a18
move the trace liveness check from #4312 (#4315) 2020-09-16 23:59:58 -07:00
Julius Kunze
c6b7269480
Support non-fragmenting mask of reshape (#4264) 2020-09-16 23:58:32 -07:00
Benjamin Chetioui
504e282788
[jax2tf] Fix precision casting problem in convolution. (#4306)
In Python 3.6 (maybe 3.7 too?), the lax.Precision enumeration
was not implicitly casted to int, which made the construction
of the xla_data_pb2.PrecisionConfig object fail in the conversion
of convolution.
2020-09-17 08:41:38 +03:00
Stephan Hoyer
877053d8ab
Add jax.linear_transpose (#3398)
* Add jax.linear_transpose

Co-authored-by: Matthew Johnson <mattjj@google.com>

* add failing test for complex numbers

* Add picky dtype check for linear_transpose

* Lint fix

* Allow truncating dtypes to match inputs in linear_transpose

* Fix typo in shape check error

* improve docstring

* Don't support integer inputs; better docstring

* fixup

* Fix doctest

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-09-16 20:29:19 -07:00
Matthew Johnson
325d3bc71d
improve an escaped tracer error message (#4312)
* improve an escaped tracer error message

Before this commit, encountering an escaped tracer in a specific way
would lead to a bad internal error. This change
1. raises an UnexpectedTracerError instead, and
2. includes in the error message the user source line which created the
tracer.

* deflake

* replace _live propety with _assert_live method

Thanks @jekbradbury !
2020-09-16 15:59:50 -07:00
Jake Vanderplas
e18a973198
implement jnp.apply_over_axes (#4225) 2020-09-16 13:30:08 -07:00
George Necula
dcaa28c624
[jax2tf] More convolution test disabling (#4304) 2020-09-16 15:29:14 +03:00
George Necula
a433c16feb
[jax2tf] Disable some convolution tests (#4303) 2020-09-16 15:08:46 +03:00
Benjamin Chetioui
1f95414f94
[jax2tf] Add tests for the conversion of conv_general_dilated (#4222)
* [jax2tf] Add tests for the conversion of conv_general_dilated.

This also adds the precision argument to the tfxla call which
was previously ignored.

* Separate orthogonal tests.
2020-09-16 11:46:32 +03:00
George Necula
a943056160
[jax2tf] Enable testing for SVD on TPU for float16 (#4288) 2020-09-16 11:44:07 +03:00
Tony
1aeaa4a057
Fix code rendering in optimizers documentation (#4296)
* Fix code rendering in optimizers documentation

* Fix misnamed variable
2020-09-15 23:28:51 -07:00
Jake Vanderplas
19201f4b4a
Mention in docstring when function is not implemented (#4297) 2020-09-15 18:16:36 -07:00
Jake Vanderplas
2811b9b239
Fix device_put_sharded() for concrete values (#4298) 2020-09-15 17:57:54 -07:00
Jake Vanderplas
95287a056c
Add api.device_put_sharded() (#4287) 2020-09-15 16:08:21 -07:00
Matthew Johnson
5520948b0c
tweak traceback for unbound axis names (#4295) 2020-09-15 12:36:53 -07:00
Jake Vanderplas
9c393812c4
jnp.divide: remove obsolete condition for py2 behavior (#4286) 2020-09-15 10:58:47 -07:00
Matthew Johnson
2678a4647a
omnistaging on by default (#4038) 2020-09-15 08:06:46 -07:00
Matthew Johnson
6af476900a
update version and changelog for pypi (#4294) jax-v0.1.77 2020-09-15 08:00:47 -07:00
Peter Hawkins
cefa93f2ed
Lower LU decomposition to a custom TPU implementation for float32 types. (#4291) 2020-09-15 09:04:54 -04:00
Benjamin Chetioui
58a117fe0d
Modifies eig_p and related operations to take advantage of the new jaxlib geev API (#4266)
* Add options to compute L/R eigenvectors in geev.

The new arguments are by default set to True to ensure backwards
compatibility between jaxlib and jax.

Reformulate eig-related operations based on the new geev API.

* Addressed hawkinsp's comments from google/jax#3882.

Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>
2020-09-15 11:45:15 +03:00
Benjamin Chetioui
a5c2c4729e
[jax2tf] Added support for x64 for the remaining test files (#4282)
* [jax2tf] Added support for x64 in other test files.

This includes:
- control_flow_ops_test.py
- jax2tf_test.py
- saved_model_test.py
- stax_test
2020-09-15 11:40:07 +03:00
Benjamin Chetioui
4e04d4eaf3
[jax2tf] Build a primitive harness for test_type_promotion. (#4279)
* [jax2tf] Build a primitive harness for test_type_promotion.

We were previously generating the cases using `jtu.cases_from_list`,
which by default dropped 2 test cases (JAX_NUM_GENERATED_CASES=10,
number of generated cases = 12).

* [jax2tf] Fix the generated test cases for test_type_promotion.
2020-09-15 09:35:35 +03:00
Trevor Cai
9040336710
Allow device_get to pass Python scalars through unchanged (#4283)
* Allow device_get to pass Python scalars through unchanged

* address comment
2020-09-14 18:35:41 -07:00
Matthew Johnson
7569e80014
revert #4277 (google failure) (#4281)
* revert #4277 (google failure)

Some downstream user is relying on the rank of stax's biases being 1.

* only revert one change
2020-09-14 12:31:51 -07:00
Stephan Hoyer
6bd3216b26
Simplify the interface for host_callback.id_tap (#4101)
* Simplify the internal interface for host_callback.id_tap

This is a breaking change for `id_tap` users (but not `id_print` users).

This makes it easier to use (and type check)  ``tap_func``, because the
expected signature is now ``tap_func(arg, transforms)`` vs
``tap_func(arg, *, transforms, **kwargs)``.

Most of the test changes are just adding whitespace/indentation, but I've
also slightly changed the way transformations are printed.
2020-09-14 12:47:28 +03:00
Benjamin Chetioui
2ff3479239
[jax2tf] Fix tests when running with JAX_ENABLE_X64=1. (#4261)
Fixed tests:
- test_binary_elementwise
- dynamic_update_slice
- fft
- population_count
- test_unary_elementwise
- top_k
- select_and_gather_add
2020-09-14 12:34:31 +03:00
Benjamin Chetioui
7e6d114f7d
[jax2tf] Add converted primitives without tests to the generated doc. (#4248)
* [jax2tf] Add converted primitives without tests to the generated
doc.

* Ignore some primitives in the output of untested primitives.

* Added control_flow_ops_test to template and updated primitives.

* Removed svd from the list of missing tests.

Was just included because I run the tests using
JAX_SKIP_SLOW_TESTS=1, which didn't run the SVD tests. Patched
the generated file manually.
2020-09-14 11:35:43 +03:00
Benjamin Chetioui
fa827a59d2
[jax2tf] Added the last comments from the jax2tf doc inside the (#4249)
correctness_stats code.

In principle, all the relevant documentation that was in the doc
has been moved to the new documentation & comments of categorize.
2020-09-14 11:33:58 +03:00
Peter Buchlovsky
6e2fa39efc
[jax2tf] Fix lax.div_p (#4263) 2020-09-14 11:33:02 +03:00
Roman Novak
9fc4353c5b
Avoid rank promotion in stax biases (#4277)
* Avoid rank promotion in stax biases

* remove itertools
2020-09-13 19:21:17 -07:00
Roman Novak
38b43ef95d
Avoid rank promotion in np.outer (#4276) 2020-09-13 19:20:31 -07:00
Alex Minnaar
64bead2093
fixing typo (#4273)
I assume "...one of more type parameters..." was intended to read "...one or more type parameters..."
2020-09-12 13:10:01 -07:00
Matthew Johnson
f039f6daf9
thread backend in pxla.replicate (#4272)
* thread backend in pxla.replicate

fixes #4223

* add test for #4223
2020-09-11 22:40:12 -07:00
Jake Vanderplas
83b4f3b97c
Cleanup: use _canonicalize_axis() utility where possible (#4270) 2020-09-11 16:49:18 -07:00
Skye Wanderman-Milne
ee9dccf39e
Move failing CPPJitTest test case to PythonJitTest (#4268) 2020-09-11 12:12:34 -07:00
Skye Wanderman-Milne
6dc161cf27
Pin pygments version in RTD build. (#4267)
This fixes our RTD failures, which were caused by RTD installing an older version of pygments:
```
jupyterlab-pygments 0.1.1 requires pygments<3,>=2.4.1, but you'll have pygments 2.3.1 which is incompatible.
nbconvert 6.0.1 requires pygments>=2.4.1, but you'll have pygments 2.3.1 which is incompatible.
```
2020-09-11 11:16:54 -07:00
Alvaro
ca1d8f4109
Fixing weird behavior in segment_sum when num_segments is None (#4034)
Co-authored-by: alvarosg <alvarosg@google.com>
2020-09-11 13:51:42 -04:00
Jake Vanderplas
3b7329c92e
Call check_arraylike() in jax.numpy reductions (#4195) 2020-09-11 10:04:47 -07:00
Adam Paszke
40fb01b4bd Extend axis env while translating the pmapped jaxpr to XLA
This is normally unnecessary, because the XLA translation usually
doesn't bind any of the primitives in the jaxpr, but this is not true in
case of scan! Its translation rule reevaluates the jaxpr as a function,
and if it contains collectives such as `axis_index` it can fail due to
axis being missing.
2020-09-11 17:56:32 +02:00
Jake Vanderplas
8a18b10fa4
implement jnp.apply_along_axis (#4253) 2020-09-11 08:47:05 -07:00
Qiao Zhang
7e694bd3b6
Update README to point to jaxlib-0.1.55. (#4256) 2020-09-10 18:28:44 -07:00
Skye Wanderman-Milne
fc39332c4f
Clarify when jaxlib version should be bumped. (#4250) 2020-09-10 15:53:04 -07:00
Jake Vanderplas
2a33b3d388
fix documentation typo (#4252) 2020-09-10 11:23:29 -07:00
Qiao Zhang
82af356b4c
Bump TF hash to get an upstream LLVM GCC fix. (#4251) jaxlib-v0.1.55 2020-09-10 10:10:33 -07:00
Peter Hawkins
cf65f6b24e
Change lax_linalg.lu to return a permutation representation of the partial pivoting information. (#4241)
The permutation is more efficiently computed during the decomposition on TPU, and the only use case that would not require us to compute it would be for evaluating determinants.
2020-09-10 11:16:35 -04:00
Peter Hawkins
b67e42a373
Revert "Revert "Delete batching.last. (#4148)" (#4160)" (#4242)
This reverts commit 36846e0ed96cc613e419ac85d9c3d54a49aa9ebc.
2020-09-10 09:38:14 -04:00
Qiao Zhang
26a53ae554
Add comments for residuals from f_bwd. (#4244) 2020-09-10 13:58:28 +03:00