9784 Commits

Author SHA1 Message Date
Peter Hawkins
4e21922055 Use imports relative to the jax package consistently, rather than .-relative imports.
This is more consistent, since currently we use a mix of both styles. It may also help pytype yield more accurate types.

PiperOrigin-RevId: 412057514
2021-11-24 07:48:29 -08:00
jax authors
b10a306266 Merge pull request #8522 from marcvanzee:compare
PiperOrigin-RevId: 412012879
2021-11-24 03:17:04 -08:00
jax authors
3c34b62c66 Merge pull request #8676 from gnecula:clean_hcb
PiperOrigin-RevId: 412006660
2021-11-24 02:44:31 -08:00
George Necula
0915f6d6fa mend 2021-11-24 11:57:28 +02:00
George Necula
277a1d775e [hcb] Cleanup to account for changes in minimum jaxlib version
We can assume now that jaxlib has the support for CustomCall.
2021-11-24 11:47:11 +02:00
Peter Hawkins
839d410de0 [MLIR] Move most MLIR translation rules into lax.
PiperOrigin-RevId: 411942327
2021-11-23 18:58:28 -08:00
jax authors
bab16a9b0c Merge pull request #8661 from jakevdp:ad-weak-types
PiperOrigin-RevId: 411930608
2021-11-23 17:30:30 -08:00
Peter Hawkins
83d8c6c238 Split slice/update_slice/gather/scatter out of jax._src.lax.lax into jax._src.lax.slicing.
To solve a circular dependency problem where some functions in jax._src.lax.lax depend on slicing, I moved a number of utility functions, e.g., standard_primitive, into a new module `jax._src.lax.utils`. Only utilities that need to be present at module import time were moved.

PiperOrigin-RevId: 411921794
2021-11-23 16:35:18 -08:00
Jake VanderPlas
496e400c71 [x64] Make autodiff respect weak types 2021-11-23 15:04:08 -08:00
jax authors
28b3c46b9b Merge pull request #8663 from jakevdp:stray-breakpoint
PiperOrigin-RevId: 411875485
2021-11-23 12:54:16 -08:00
Peter Hawkins
4204a25c91 Split convolution functions out of jax._src.lax.lax and into a separate module (jax._src.lax.convolution).
No public API changes.

PiperOrigin-RevId: 411871903
2021-11-23 12:35:50 -08:00
Jake VanderPlas
e14eaf0664 cleanup: remove stray debugging breakpoint 2021-11-23 12:17:08 -08:00
jax authors
ca443b5cb9 Merge pull request #8656 from froystig:backend-type-pmap-lowering
PiperOrigin-RevId: 411842480
2021-11-23 10:24:44 -08:00
Peter Hawkins
6cf5c4affb [XLA:CPU] Implement 3D convolutions using Eigen.
Eigen convolutions are much faster than the naive fallback IR.

[JAX] Relax jax2tf convolution test tolerance.

PiperOrigin-RevId: 411837376
2021-11-23 10:03:29 -08:00
jax authors
dd9afcfeb0 Merge pull request #8658 from gnecula:tf_arange_improve_error
PiperOrigin-RevId: 411824157
2021-11-23 09:04:45 -08:00
George Necula
ddc3a126e2 Improve error when jnp.arange is used with non-constant arguments 2021-11-23 16:19:31 +02:00
jax authors
2ec1488876 Merge pull request #8629 from jakevdp:dtypes-dtype
PiperOrigin-RevId: 411791488
2021-11-23 05:58:15 -08:00
George Necula
a1dee027c4 Disable recent change to wrap jax2tf lowered code with tf.function
Recent change: https://github.com/google/jax/pull/7839

PiperOrigin-RevId: 411776541
2021-11-23 04:26:48 -08:00
jax authors
db0a48ac8c Merge pull request #8657 from gnecula:tf_lint
PiperOrigin-RevId: 411751197
2021-11-23 01:53:36 -08:00
George Necula
72d9d35555 Fix lint and mypy errors 2021-11-23 11:17:37 +02:00
jax authors
055df6d9da Merge pull request #8653 from jakevdp:fix-x64-context
PiperOrigin-RevId: 411745165
2021-11-23 01:15:29 -08:00
jax authors
9781f365a1 Merge pull request #7839 from gnecula:tf_jit
PiperOrigin-RevId: 411742459
2021-11-23 00:56:49 -08:00
George Necula
43433078bc [jax2tf] Force TF compilation for code under jax.jit.
Previously, jax.jit was ignored by jax2tf. This can result in the
converted code being much slower than the JAX core, unless the
user adds an explicit `tf.function(jit_compile=True)`. With this
change that wrapper is added automatically for all code fragments
under jax.jit. Note that most jax.numpy functions are annotated
with jax.jit, so with this change they will all be compiled.

When doing this I ran into problems with tf.custom_gradient and
tf.function. As documented in the
[tf.custom_gradient](https://www.tensorflow.org/api_docs/python/tf/custom_gradient)
documentation, you get a LookupError when trying to build the gradient
of a tf.function, even if it has a tf.custom_gradient defined. The
recommended solution is to add a tf.stop_gradient. This is safe, since
jax2tf will always wrap the converted functions with a tf.custom_gradient.
2021-11-23 10:24:46 +02:00
Roy Frostig
9f83345784 complete annotation for XLA bridge functions that take a backend name or object 2021-11-22 18:22:11 -08:00
jax authors
0f0bc3ee14 Merge pull request #8449 from froystig:aot-pmap
PiperOrigin-RevId: 411684527
2021-11-22 17:29:43 -08:00
Jake VanderPlas
c4d9c4674f [x64] regularize dtype helpers 2021-11-22 15:35:12 -08:00
jax authors
b0d334a881 Merge pull request #8652 from hawkinsp:scatterdocs
PiperOrigin-RevId: 411657624
2021-11-22 15:07:56 -08:00
Jake VanderPlas
c5c78b5f6d [x64] make x64_context_test more robust 2021-11-22 14:54:30 -08:00
jax authors
6cc7d67484 Merge pull request #8648 from jakevdp:arange-dtype
PiperOrigin-RevId: 411640866
2021-11-22 13:54:44 -08:00
Peter Hawkins
361b7367cc Implement the select_and_gather_add translation rule via lower_fun.
This allows us to share the logic in the MLIR lowering.

PiperOrigin-RevId: 411639693
2021-11-22 13:49:48 -08:00
Roy Frostig
20a1517eeb factor tuple conversions into common pmap setup logic 2021-11-22 13:49:44 -08:00
Roy Frostig
cf64a945cf refine pmap-related annotations 2021-11-22 13:49:44 -08:00
Peter Hawkins
4679f455f9 Change the default out-of-bounds behavior for jax.ops.segment_... to FILL_OR_DROP.
This matches the documented behavior.

Fixes https://github.com/google/jax/issues/8634

PiperOrigin-RevId: 411635687
2021-11-22 13:32:58 -08:00
Peter Hawkins
5415306257 Make lax.reduce_window variadic.
This is similar to the support in lax.reduce(), where the operands and init_values become pytrees. This is a strict superset of the current API, so users should not need updates.

Variadic lax.reduce_window() is only supported on CPU and TPU at the moment, not GPU.

PiperOrigin-RevId: 411632993
2021-11-22 13:21:37 -08:00
Peter Hawkins
f3aa5fa92f Document lax.GatherScatterMode.
Recommend the .at[...] property in the docstrings for lax.scatter_ operators.

Add several missing lax.scatter_ operators to the index.
2021-11-22 15:43:02 -05:00
Peter Hawkins
ad6ce74d67 Skip some polar decomposition tests that fail on A100.
Works around https://github.com/google/jax/issues/8628

PiperOrigin-RevId: 411604717
2021-11-22 11:18:22 -08:00
Jake VanderPlas
52044556d0 [x64] avoid dtype conversions for arange arguments 2021-11-22 11:00:07 -08:00
Peter Hawkins
f4351e8419 Disable QDWH tests that fail on GPU and TPU.
PiperOrigin-RevId: 411591003
2021-11-22 10:21:41 -08:00
Peter Hawkins
dcded6a8f9 Fix incorrect gradient for base-dilated reduce window.
https://github.com/google/jax/pull/8606 introduced a runtime error where as a consequence of the move, a reference to `slice` became a reference to the builtin slice operator instead of `lax.slice`.

After fixing that and while added a test, I noticed that the gradient was wrong before: we should have been slicing the result, not the operand in the transpose rule's handling of base dilation.

Also enable some TPU tests that now pass since we have variadic reduce-window support on TPU.

PiperOrigin-RevId: 411579650
2021-11-22 09:34:10 -08:00
Roy Frostig
bf1dd3a848 refactor pmap staging, lowering, and compilation 2021-11-22 09:19:14 -08:00
Roy Frostig
3328fa48b8 rename backend to backend_name in parallel lowering 2021-11-22 09:19:14 -08:00
Roy Frostig
9f82d78007 typecheck pmap executable call arguments 2021-11-22 09:19:13 -08:00
Marc van Zee
6eca387582 Merge branch 'main' of http://www.github.com/google/jax into compare 2021-11-22 18:07:24 +01:00
Marc van Zee
7b25e05fd1 Improve logics for numerical comparison 2021-11-22 18:02:44 +01:00
Roy Frostig
fcdc0a6c1a ahead-of-time lowering and compilation frontend for pmap 2021-11-22 08:33:04 -08:00
Roy Frostig
8f88b89744 factor pmap compilation into lowering and compilation separately
Includes minor changes to mesh computation lowering/compilation, for
interface consistency.
2021-11-22 08:31:22 -08:00
Peter Hawkins
d262bae88b Split jax.interpreters.xla up into three pieces:
* jax._src.device_array, which contains the definition of DeviceArray.
* jax.interpreters.xla, which contains code for lowering jaxprs into XLA computations.
* jax._src.dispatch, which contains code for executing primitives and jit-compiled functions (xla_call_p's impl logic).

The purpose of splitting up this file is that I would like to treat jax.interpreters.mlir lowering as an alternative to jax.interpreters.xla, but we wish to share the device_array and computation dispatch pieces. Currently jax.interpreters.mlir duplicates most of the dispatch logic. (That refactoring is for a future change; this change just moves the existing code around.)

PiperOrigin-RevId: 411565432
2021-11-22 08:22:43 -08:00
jax authors
34855def13 Merge pull request #8643 from google:gnecula-patch-1
PiperOrigin-RevId: 411558417
2021-11-22 07:47:11 -08:00
George Necula
263a7ff1b8
Update README.md 2021-11-22 17:17:09 +02:00
Peter Hawkins
fca3da51cd Switch most uses of jax.lax in jax2tf to use the public API.
The only reason jax2tf needs access to the internals of jax.lax is when it wants to reuse various translation rule helpers; keep those as explicit internal imports.

This change is partially to minimize churn as jax._src.lax is restructured.

PiperOrigin-RevId: 411276891
2021-11-20 10:13:26 -08:00