George Necula
235eb8c2b4
Copybara import of the project:
...
--
1ecf4f02891cad70cc8f094b49cf2458105ca366 by George Necula <gcnecula@gmail.com>:
[jax2tf] Change the conversion of dot_general to use XLA op.
Instead of converting the dot_general to a sea of TF ops, when
we enable_xla we just use the XLA op. This has the advantage
that it also supports the preferred_element_type.
Fixed bug with passing the precision parameter to TF.
Also improved tests to print the HLO in case of numerical errors.
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/6717 from gnecula:tf_dot 1ecf4f02891cad70cc8f094b49cf2458105ca366
PiperOrigin-RevId: 373326655
2021-05-12 02:30:15 -07:00
Jake VanderPlas
71a25cdac1
DOC: add examples to lax function docstrings
2021-04-29 09:48:52 -07:00
Jake VanderPlas
ca684df0e9
DOC: add example for lax.dynamic_update_slice
2021-04-23 09:10:43 -07:00
Lukas Geiger
f7f42694d9
Add support for preferred_element_type
arg in convolutions
2021-04-22 10:29:31 +02:00
Skye Wanderman-Milne
feb79e5698
Fix some Cloud TPU test failures.
...
The new select_and_gather_add logic was inspired by
3a35f7072a
.
2021-04-21 00:37:02 +00:00
Lena Martens
fa5e19b630
Fix Zero handling in select_jvp.
2021-04-19 17:03:07 +01:00
Peter Hawkins
14d991dd90
Move jax.config to jax._src.config.
...
PiperOrigin-RevId: 369230109
2021-04-19 08:53:12 -07:00
Matthew Johnson
9d6263a743
support implicit broadcasting in transpose rules
2021-04-16 12:51:11 -07:00
Peter Hawkins
26e9ebcdae
Move jax.api to jax._src.api.
...
PiperOrigin-RevId: 368233837
2021-04-13 09:43:24 -07:00
Peter Hawkins
0f1520b6d2
Enable variadic select_and_gather on TPU.
2021-04-13 09:09:10 -04:00
jax authors
8f2502324a
Merge pull request #6408 from LenaMartens:changelist/367979622
...
PiperOrigin-RevId: 367991796
2021-04-12 06:42:24 -07:00
jax authors
ce67e563a1
Merge pull request #6375 from gnecula:mask_clean
...
PiperOrigin-RevId: 367985125
2021-04-12 05:50:19 -07:00
Lena Martens
b4f66d2676
Fix handling of ad.Zero in _select_and_scatter_add_transpose.
...
Fixes #6403 .
2021-04-12 13:07:40 +01:00
Peter Hawkins
3a35f7072a
Implement select_and_gather_add using variadic reducewindow on CPU.
2021-04-09 14:40:43 -04:00
jax authors
438b56c483
Fix typo in rng_bit_generator comment.
...
PiperOrigin-RevId: 367460802
2021-04-08 10:42:45 -07:00
George Necula
0e280bbac0
[masking] Remove references to masking.Poly from the lax.py and lax_numpy.py
...
Previously, in order to increase the coverage of masking we added special
cases in lax.py and lax_numpy.py to avoid exceptions in presence of
masking.Poly.
For example:
```
if not isinstance(d, masking.Poly):
if some_check(d):
raise ValueError
```
All such conditionals make the code behave potentially different when
tracing with masking.Poly than when tracing with concrete shapes, which
makes it hard to ensure soundness.
Perhaps the most eggregious was:
```
if type(i) is Poly:
# dummy index if i is polynomial, doesn't matter for shape inference
i = 0
```
2021-04-08 17:45:14 +03:00
jax authors
3a9ce3990e
Merge pull request #6345 from gnecula:shape_poly
...
PiperOrigin-RevId: 367416742
2021-04-08 06:21:12 -07:00
George Necula
2e9e824289
Cleanup and fix triangular_solve
2021-04-08 10:42:38 +03:00
George Necula
99d5f09b29
Fix select and eigh
2021-04-08 10:42:38 +03:00
George Necula
5750ec074a
Fix scatter
2021-04-08 10:42:38 +03:00
George Necula
551a89cfe9
Fixes for slice
2021-04-08 10:42:38 +03:00
George Necula
cbe5f54cca
Added support for lax.pad, and more error checking
2021-04-08 10:42:38 +03:00
George Necula
4f9ac031d7
Add some support for convolutions
2021-04-08 10:42:38 +03:00
George Necula
56e41b7cd7
Add support for cummax
2021-04-08 10:42:38 +03:00
George Necula
e37727cbce
[jax2tf] Implementation of a parametric shape-polymorphism feature for jax2tf.
...
See the PR description.
2021-04-08 10:42:38 +03:00
Peter Hawkins
6a6f13e1b0
[JAX] Move contents of jax/dtypes.py to jax/_src/dtypes.py.
...
PiperOrigin-RevId: 367345623
2021-04-07 19:35:51 -07:00
jax authors
42e01ee2fa
Merge pull request #6339 from google:djax-dot
...
PiperOrigin-RevId: 367134944
2021-04-06 19:34:47 -07:00
Matthew Johnson
abf2d69262
djax: add analogue of lower_fun, dot
2021-04-06 13:17:13 -07:00
Jake VanderPlas
93a1882e86
lax.reduce_precision: add basic batching & masking rule
2021-04-05 14:16:50 -07:00
Jake VanderPlas
8e789c7380
Run doctest on all source files except jax2tf
2021-04-05 10:39:59 -07:00
Jake VanderPlas
33fde77bb1
Add lax.reduce_precision()
2021-04-05 09:54:14 -07:00
Matthew Johnson
9205b6f125
remove dead lax._dot_using_sum_of_products
2021-04-02 14:23:02 -07:00
jax authors
1f1d3dffe2
Merge pull request #6182 from hawkinsp:reduce
...
PiperOrigin-RevId: 365902926
2021-03-30 14:53:36 -07:00
Peter Hawkins
3fc1fdb148
Add a JVP rule for the general case of lax.reduce.
2021-03-30 17:31:47 -04:00
Matthew Johnson
aa2472db0c
add scatter_add jet rule, fixes #5365
...
could use a better test though...
2021-03-30 14:04:40 -07:00
Matthew Johnson
2b79264354
remove disable_omnistaging mechanism
2021-03-29 15:26:57 -07:00
jax authors
2022141b13
Merge pull request #6208 from majnemer:int-conv
...
PiperOrigin-RevId: 365544250
2021-03-29 04:20:05 -07:00
Matthew Johnson
8547c71bfd
simplify public lax.convert_element_type api
...
Specifically:
1. don't expose weak_type in the public api, as it's jax-internal
2. don't make new_dtype optional, which could make bugs easier
This change keeps the public API simpler, and also makes
convert_element_type match the ConvertElementType HLO. As an internal
API we can call lax._convert_element_type just like before.
2021-03-28 10:32:02 -07:00
David Majnemer
7defa05009
Allow integer/boolean convolutions
2021-03-24 23:20:30 -07:00
Matthew Johnson
89768a3d28
add jax_default_matmul_precision flag & context mngr
2021-03-24 14:03:58 -07:00
Matthew Johnson
214d273d8c
undo changes to host_callback (not needed anymore)
2021-03-21 19:43:12 -07:00
Matthew Johnson
fe4d12c10f
move logic to traceable
2021-03-21 19:38:12 -07:00
Matthew Johnson
8c3125c172
fix convert_element_type on large Py int inputs
2021-03-21 19:08:59 -07:00
Matthew Johnson
af59542d00
Re-applying the changes in #6014 , after they had to be rolled-back.
...
PiperOrigin-RevId: 364200195
2021-03-21 13:40:20 -07:00
Matthew Johnson
57d5c6af5f
add clz primitive
2021-03-19 22:54:36 -07:00
Roy Frostig
7427991819
skip scalars when broadcasting for batch dimension agreement
2021-03-19 21:47:16 -07:00
jax authors
4f8814a760
Copybara import of the project:
...
--
bf15ba5310d5f9009571928f70548bcbc7e856c3 by Matthew Johnson <mattjj@google.com>:
don't device transfer in convert_element_type
Co-authored-by: Qiao Zhang <zhangqiaorjc@google.com>
PiperOrigin-RevId: 363995032
2021-03-19 16:35:37 -07:00
Matthew Johnson
bf15ba5310
don't device transfer in convert_element_type
...
Co-authored-by: Qiao Zhang <zhangqiaorjc@google.com>
2021-03-19 13:42:33 -07:00
Jake VanderPlas
5f51d4fb1d
Make lax._const() work for non-canonical dtypes
2021-03-17 13:07:53 -07:00
Peter Hawkins
328930b917
Increase minimum jaxlib version to 0.1.62.
2021-03-16 15:11:36 -04:00