617 Commits

Author SHA1 Message Date
Peter Hawkins
0b470361da Change the default jnp.take mode to "fill".
Previously, `jnp.take` defaulted to clamping out-of-bounds indices into range. Now, `jnp.take` returns invalid values (e.g., NaN) for out-of-bounds indices. This change attempts to prevent latent bugs caused by inadvertent out-of-bounds indices.

The previous behavior can be approximated using the "clip" or "wrap" fill modes.

PiperOrigin-RevId: 445130143
2022-04-28 06:01:56 -07:00
Matthew Johnson
4608d36340 add scan dce rule 2022-04-27 20:47:43 -07:00
Peter Hawkins
7c6a550333 Change the default scatter mode to FILL_OR_DROP.
This is a reasonably safe change, because it has no effect on the forward pass of a computation: the default behavior (PROMISE_IN_BOUNDS) also drops out-of-bounds scatters.

This change does however affect the transpose (gradient) of a scatter with out-of-bounds indices: the gradient of a PROMISE_IN_BOUNDS scatter is a PROMISE_IN_BOUNDS gather, and a PROMISE_IN_BOUNDS gather clips out-of-bounds indices into range. This is not mathematically correct: a dropped scatter index does not contribute to the primal output, and so its transpose should yield a zero cotangent.

After this change, the gradient of a default scatter is a gather with a fill value of 0: i.e., the indices that were dropped do not make gradient contributions, which is mathematically correct.

Separately, I am working towards switching out-of-bounds gather() operations to also have FILL_OR_DROP semantics, although that change is more disruptive because a number of users have out-of-bounds indices in their gather()s.

Issues: https://github.com/google/jax/issues/278 https://github.com/google/jax/issues/9839
PiperOrigin-RevId: 444935241
2022-04-27 12:26:55 -07:00
jax authors
1bac12af88 Merge pull request #10451 from jakevdp:qr-jvp
PiperOrigin-RevId: 444885520
2022-04-27 09:43:32 -07:00
Matthew Johnson
bd00926b63 don't bind scan on jaxpr_known if no outputs 2022-04-26 21:37:35 -07:00
Matthew Johnson
ebbad07ce7 [remove-units] roll forward #10448, fix dce bug 2022-04-26 20:34:14 -07:00
Matthew Johnson
823ad552d6 Copybara import of the project:
--
4680b86ff7f468429a0820b4f8c7f64ffd1a1cad by Matthew Johnson <mattjj@google.com>:

[remove-units] prevent scan partial eval from introducing units

PiperOrigin-RevId: 444698613
2022-04-26 16:29:54 -07:00
jax authors
c45bf20c38 Merge pull request #10448 from mattjj:remove-units-scan
PiperOrigin-RevId: 444624012
2022-04-26 11:51:54 -07:00
Jake VanderPlas
67e0fdda82 lax.linalg.qr: allow jvp when m == n and full_matrices=True 2022-04-26 10:34:50 -07:00
Matthew Johnson
4680b86ff7 [remove-units] prevent scan partial eval from introducing units 2022-04-26 08:51:00 -07:00
Anudhyan Boral
a147046d18 Add unary xeinsum and allow named axis reductions for unary and binary xeinsums 2022-04-26 09:55:42 +00:00
Matthew Johnson
9359cc3e53 [remove-units] remove units from while partial eval 2022-04-24 21:29:48 -07:00
Matthew Johnson
fde6305012 refine const folding 2022-04-24 21:04:06 -07:00
Matthew Johnson
bf64f1843f [remove-units] prevent cond partial eval from introducing units 2022-04-24 14:28:33 -07:00
Matthew Johnson
221680fed5 remove old cond todos 2022-04-22 22:36:45 -07:00
YouJiacheng
75e990bbc3 Fix typo in _scatter_add_lower_gpu
a87b21148c doesn't notice `_scatter_add_lower_gpu` using `mlir.lower_fun` instead of `xla.lower_fun`.
I follow the change done in that commit for _scatter_lower.
2022-04-22 23:55:11 +08:00
Eugene Burmako
0ed29b63f0 [MHLO] Add MHLO lowering for erf and erfc
erf implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=319-336;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

PiperOrigin-RevId: 443665435
2022-04-22 07:54:23 -07:00
Eugene Burmako
5f16873aad [MHLO] Switch tan to use CHLO lowering
Currently, it's desugared to sin(x)/cos(x) with upcast because CHLO_TanOp
legalization doesn't support complex numbers.

tan implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1175-1177;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

PiperOrigin-RevId: 443649394
2022-04-22 06:28:01 -07:00
Eugene Burmako
636345fd67 [MHLO] Add MHLO lowerings of remaining ops blocked by the lack of complex support in CHLO
The affected ops are: acosh, asinh and atanh
(in addition to cosh which was fixed a few days ago).

acosh implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1181-1216;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

asinh implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1218-1270;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

atanh implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1272-1292;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

PiperOrigin-RevId: 443590210
2022-04-22 00:39:02 -07:00
jax authors
48551be9c9 Merge pull request #10371 from sharadmv:mlir-token-lowering
PiperOrigin-RevId: 443458234
2022-04-21 13:08:21 -07:00
Sharad Vikram
f17c09eb8d add in mlir lowering for tokens 2022-04-21 11:28:58 -07:00
Matthew Johnson
e313428f2e remove lax._device_put_raw 2022-04-20 14:49:53 -07:00
Tianjian Lu
455c9f823e [linalg] Adds full_matrices option to TPU SVD.
PiperOrigin-RevId: 443163571
2022-04-20 12:32:00 -07:00
YouJiacheng
f6ca60ec29
DOC: lax.linalg.eigh
Fix the inconsistency of variable name between docstring and source code.
Add description of eigenvalues
2022-04-20 16:23:16 +08:00
jax authors
7008b32132 Merge pull request #10296 from sharadmv:jax2tf-name-stack
PiperOrigin-RevId: 442872933
2022-04-19 11:57:19 -07:00
Tianjian Lu
5a1c5ba114 [linalg] Adds compute_uv to TPU SVD.
PiperOrigin-RevId: 442864883
2022-04-19 11:28:43 -07:00
Sharad Vikram
5ff2e8eb4c Fix name stack bugs 2022-04-19 11:14:41 -07:00
Peter Hawkins
21e1f8c3d1 [JAX] Delete last references to conv/dot translation rules.
Replace references with MHLO equivalents.

PiperOrigin-RevId: 442675847
2022-04-18 17:42:47 -07:00
Peter Hawkins
a48752a578 [MHLO] Remove most XLA translation rules.
Almost all XLA translation rules have MHLO equivalents at this point, and there are no code paths that use the XLA translation rules in preference to their MLIR equivalents.

PiperOrigin-RevId: 442547482
2022-04-18 08:28:35 -07:00
Peter Hawkins
0150d15cb2 Increase minimum jaxlib version to 0.3.7.
Drop backwards compatibility with older jaxlib versions.
2022-04-18 08:09:50 -04:00
Peter Hawkins
c4ba450867 [MHLO] Add explicit XLA translation rules for primitives that lack MHLO lowerings that rely on standard_primitive registering a translation rule.
At the moment this change does nothing since standard_primitive already registers these same translation rules. The change is in preparation for removing the behavior of standard_primitive of registering an XLA translation rule.

PiperOrigin-RevId: 442222533
2022-04-16 07:01:19 -07:00
Eugene Burmako
b267fd4336 [MHLO] Add MHLO lowering for cosh
Here's the corresponding old bridge lowering:
https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1294-1309;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

getConstantLike wasn't supporting complex numbers, and proper support required
non-trivial work, so in the meanwhile I've hacked up something that works
for static shapes to unblock the JAX use case (which currently only uses static shapes).

PiperOrigin-RevId: 442083258
2022-04-15 13:26:20 -07:00
Parker Schuh
58ef6cb9e1
Merge branch 'main' into opt-barrier-gpu 2022-04-14 19:31:12 -07:00
Peter Hawkins
c2fe97ae01 Improve precision of chlo.sinh.
Update chlo.sinh lowering to match xla::Sinh(), see https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1311?q=xla%20sinh

[JAX] Use chlo.sinh instead of the XLA client library HLO lowering.

PiperOrigin-RevId: 441851170
2022-04-14 14:10:26 -07:00
jax authors
6914e35af1 Merge pull request #10270 from mattjj:djax-iree
PiperOrigin-RevId: 441812895
2022-04-14 11:33:10 -07:00
Matthew Johnson
d21b958f30 add some simple iree tests
This passes, though two of the interesting tests fail with what might be IREE
bugs (and so are currently skipped):

```shell
JAX_PLATFORMS='iree' pytest -n auto tests/core_test.py tests/api_test.py -k Dynamic
```
2022-04-14 10:55:00 -07:00
Peter Hawkins
6c1461b52b [MHLO] Add MHLO lowerings for triangular_solve, cholesky, and schur.
PiperOrigin-RevId: 441769591
2022-04-14 08:38:21 -07:00
Peter Hawkins
4806c29bf7 [MHLO] Add MHLO lowerings for FFT ops.
PiperOrigin-RevId: 441768017
2022-04-14 08:31:17 -07:00
Parker Schuh
b0c41b9023 Switch gpu, cpu and jax2tf to use the new OptimizationBarrier op.
This change should unblock the removal of the cond and while widgets.
2022-04-13 22:51:36 -07:00
Peter Hawkins
665df8dfaa [MHLO] Add an MHLO lowering for rng_bit_generator.
PiperOrigin-RevId: 441628987
2022-04-13 18:09:36 -07:00
Peter Hawkins
21f95d531b Remove use of xla.lower_fun in SVD translation rule.
This is the only use of xla.lower_fun that is still needed (as a fallback) when the non-MHLO path is removed.

PiperOrigin-RevId: 441538472
2022-04-13 11:44:45 -07:00
Jake VanderPlas
1a8c57d272 better errors: check for callability of lax.control_flow arguments 2022-04-13 10:39:01 -07:00
Jake VanderPlas
7bfc86e17f Fix arguments to schur translation rule 2022-04-13 09:50:33 -07:00
Peter Hawkins
cb4abe754a [MHLO] Separate registrations for collective and initial_style primitives from the XLA translation rule registration.
Change in preparation for removing XLA translation rules for many primitives. However, even after the MHLO switch we still need to tag collective and initial_style primitives.

PiperOrigin-RevId: 441474701
2022-04-13 07:26:26 -07:00
Peter Hawkins
ad8e6ada4e [MHLO] Change jax.xla_computation() to use MHLO lowering internally.
Change in preparation for removing the non-MHLO lowering path.

PiperOrigin-RevId: 441460875
2022-04-13 06:28:38 -07:00
Sharad Vikram
4392b07022 Add tests for higher order primitives 2022-04-12 18:12:44 -07:00
Matthew Johnson
4354f355a8 prototyping dynamic shapes
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-04-11 22:10:47 -07:00
Sharad Vikram
0fa1eddd25 Adds simple effect types to jaxprs 2022-04-11 11:50:41 -07:00
Matthew Johnson
902fc0c3d2 Remove invertible_ad since it's not in use.
PiperOrigin-RevId: 440890949
2022-04-11 07:56:58 -07:00
jax authors
0bfb3efcd7 [JAX] Fix batch logic for approx_min/max_k
Previous logic was copied from lax.sort and was incorrect.
Since approx_top_k can handle multi-rank tensors, the only mapping we need
is to set the reduction_dim correctly.

PiperOrigin-RevId: 440445041
2022-04-08 13:50:36 -07:00