11258 Commits

Author SHA1 Message Date
YouJiacheng
b485b8e5ce implement scipy.cluster.vq.vq
also add no check_finite and overwrite_* docstring for some scipy.linalg functions
2022-04-23 03:14:32 +08:00
Eugene Burmako
0ed29b63f0 [MHLO] Add MHLO lowering for erf and erfc
erf implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=319-336;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

PiperOrigin-RevId: 443665435
2022-04-22 07:54:23 -07:00
Eugene Burmako
5f16873aad [MHLO] Switch tan to use CHLO lowering
Currently, it's desugared to sin(x)/cos(x) with upcast because CHLO_TanOp
legalization doesn't support complex numbers.

tan implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1175-1177;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

PiperOrigin-RevId: 443649394
2022-04-22 06:28:01 -07:00
Eugene Burmako
636345fd67 [MHLO] Add MHLO lowerings of remaining ops blocked by the lack of complex support in CHLO
The affected ops are: acosh, asinh and atanh
(in addition to cosh which was fixed a few days ago).

acosh implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1181-1216;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

asinh implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1218-1270;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

atanh implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1272-1292;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

PiperOrigin-RevId: 443590210
2022-04-22 00:39:02 -07:00
jax authors
5013bd2e3a Merge pull request #10402 from froystig:aot-jit-avoid-trivial
PiperOrigin-RevId: 443533232
2022-04-21 18:13:10 -07:00
Xin Zhou
bd077817dc [mhlo] Add result type inference for mhlo.broadcast.
PiperOrigin-RevId: 443527300
2022-04-21 17:40:21 -07:00
Aart Bik
fb370b86ff Adds a ability to pass computation directly as module to backend
PiperOrigin-RevId: 443512012
2022-04-21 16:38:16 -07:00
jax authors
60223fb5f1 Merge pull request #10408 from jakevdp:xmap-flake8
PiperOrigin-RevId: 443511304
2022-04-21 16:32:23 -07:00
Jake VanderPlas
2a183cb638 Apply flake8 checks to xmap_test.py 2022-04-21 16:06:53 -07:00
jax authors
5ccb8a33e9 Merge pull request #10406 from jakevdp:noqa
PiperOrigin-RevId: 443496817
2022-04-21 15:31:58 -07:00
Roy Frostig
5c118071cb always lower/compile computations on the AOT jit path
... even trivial ones.
2022-04-21 15:30:36 -07:00
Jake VanderPlas
5782210174 CI: fix flake8 ignore declarations 2022-04-21 13:44:12 -07:00
jax authors
f1104cf8d5 Merge pull request #10389 from jakevdp:deprecated-test-util
PiperOrigin-RevId: 443462414
2022-04-21 13:24:08 -07:00
jax authors
48551be9c9 Merge pull request #10371 from sharadmv:mlir-token-lowering
PiperOrigin-RevId: 443458234
2022-04-21 13:08:21 -07:00
Jake VanderPlas
d9508304e4 Deprecate remaining functionality in jax.test_util 2022-04-21 12:12:40 -07:00
Aart Bik
c1261ccd27 Adds a wrapper to sparse tensor dialect, as part of an
an initial prototype of an alternate JAX compilation path
that emits the MLIR MHLO/CHLO dialects instead of classic XLA HLO
together with sparse tensor types.

PiperOrigin-RevId: 443438043
2022-04-21 11:48:44 -07:00
Sharad Vikram
f17c09eb8d add in mlir lowering for tokens 2022-04-21 11:28:58 -07:00
jax authors
bef5e02816 Merge pull request #10150 from jakevdp:normalize-unsigned
PiperOrigin-RevId: 443416771
2022-04-21 10:34:35 -07:00
Peter Hawkins
74346f464b [JAX] Change jnp.take_along_axis to return invalid (e.g. NaN) values for out-of-bounds indices.
Previously, out-of-bounds indices were clipped into range, but that behavior is error prone. We would rather fail in a more visible way when out-of-bounds indices are used. Future changes will migrate other JAX indexing operations to have the same semantics.

PiperOrigin-RevId: 443390170
2022-04-21 08:52:14 -07:00
Jake VanderPlas
92ca76a039 Skip normalization of unsigned indices 2022-04-20 16:04:12 -07:00
jax authors
0fc93a0a99 Merge pull request #10392 from mattjj:remove-device-put-raw
PiperOrigin-RevId: 443209240
2022-04-20 15:28:18 -07:00
Matthew Johnson
e313428f2e remove lax._device_put_raw 2022-04-20 14:49:53 -07:00
James Bradbury
38e754585f [mesh_utils] Add device/slice count checks
PiperOrigin-RevId: 443178279
2022-04-20 13:30:01 -07:00
Tianjian Lu
455c9f823e [linalg] Adds full_matrices option to TPU SVD.
PiperOrigin-RevId: 443163571
2022-04-20 12:32:00 -07:00
jax authors
29d54e3297 Merge pull request #10387 from YouJiacheng:fix-typo-of-#10381
PiperOrigin-RevId: 443154137
2022-04-20 11:52:35 -07:00
YouJiacheng
af7b94b110 Fix typo of #10381
and add a basic regression test
2022-04-21 02:17:09 +08:00
jax authors
e9588ab43f Merge pull request #10381 from YouJiacheng:remove-numpy.linalg._promote_arg_dtypes
PiperOrigin-RevId: 443133279
2022-04-20 10:37:04 -07:00
Peter Hawkins
4fd824c36f Change jnp.take_along_axis to require that its indices are of integer type.
Previously jnp.take_along_axis silently casted its indices to integers if they were not already integers.

PiperOrigin-RevId: 443124521
2022-04-20 10:05:16 -07:00
jax authors
ea0233b995 Merge pull request #10379 from hawkinsp:onehot
PiperOrigin-RevId: 443122227
2022-04-20 09:57:05 -07:00
jax authors
4dce508b2a Merge pull request #10377 from YouJiacheng:patch-8
PiperOrigin-RevId: 443119505
2022-04-20 09:45:05 -07:00
jax authors
606d1c2402 Merge pull request #10385 from jakevdp:promote-types-comment
PiperOrigin-RevId: 443116925
2022-04-20 09:39:46 -07:00
jax authors
b2132f7884 Merge pull request #10358 from LenaMartens:changelist/442788164
PiperOrigin-RevId: 443116900
2022-04-20 09:34:22 -07:00
YouJiacheng
bb2682db6d remove numpy.linalg._promote_arg_dtypes
in favor of numpy.util._promote_dtypes_inexact
2022-04-21 00:23:56 +08:00
Jake VanderPlas
2588c98586 Add comment explaining implementation in promote_types 2022-04-20 08:44:49 -07:00
jax authors
a03be320d1 Merge pull request #10368 from hawkinsp:takealongaxis2
PiperOrigin-RevId: 443076407
2022-04-20 06:22:29 -07:00
Peter Hawkins
6e6f693e6d Use lax.broadcasted_iota in jax.nn.one_hot.
Minor cleanup that means we emit one fewer MHLO op, no functional changes intended.
2022-04-20 09:09:19 -04:00
Lena Martens
762c9b4774
Apply suggestions from code review
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-04-20 10:46:33 +01:00
YouJiacheng
f6ca60ec29
DOC: lax.linalg.eigh
Fix the inconsistency of variable name between docstring and source code.
Add description of eigenvalues
2022-04-20 16:23:16 +08:00
jax authors
88ac4edf63 Merge pull request #10362 from JeppeKlitgaard:fix-random-reexport
PiperOrigin-RevId: 442936248
2022-04-19 16:04:31 -07:00
jax authors
79c435042b Merge pull request #10298 from jakevdp:bcoo-sum-duplicates
PiperOrigin-RevId: 442932036
2022-04-19 15:46:30 -07:00
Jake VanderPlas
edae0ac31f [sparse] make bcoo_sum_duplicates a primitive 2022-04-19 15:34:49 -07:00
jax authors
c930e593d2 Merge pull request #10364 from dbisk:patch-1
PiperOrigin-RevId: 442924977
2022-04-19 15:17:17 -07:00
jax authors
27d32e4a74 Merge pull request #10367 from mattjj:10366
PiperOrigin-RevId: 442909079
2022-04-19 14:17:38 -07:00
Peter Hawkins
1bed70590a [MHLO] Switch call_tf to use an MHLO lowering (attempt 2).
In passing refactor and fix some bugs in the MHLO helper code:
* mlir.ir_constant() failed to propagate its canonicalize_types argument to its callee.
* Refactor the code to convert an XLA computation to an MHLO module and to merge two MHLO modules from the XLA fallback translation rule path.
* Fix symbol (alpha) renaming of call operator callees when merging MHLO modules.

Attempt 2: In this iteration of the merge_mhlo_modules function, move all the operators into the target module first before doing any symbol table manipulation.

PiperOrigin-RevId: 442904129
2022-04-19 14:00:00 -07:00
Matthew Johnson
6f606a0b57 fix issue #10366 2022-04-19 13:18:00 -07:00
Peter Hawkins
a52f07a21b Add an optional mode= argument to jnp.take_along_axis.
This allows users of jnp.take_along_axis to override the out-of-bounds indexing behavior.
Default to "clip", which for the forward computation is identical to the current behavior. In a future change, we will change this to "fill".
2022-04-19 16:07:00 -04:00
jax authors
7008b32132 Merge pull request #10296 from sharadmv:jax2tf-name-stack
PiperOrigin-RevId: 442872933
2022-04-19 11:57:19 -07:00
Peter Hawkins
e1b606934f Temporarily revert: Change default jnp.take_along_axis gather mode to "fill".
Some tests were broken by the change; reverting this PR for the moment while debugging the problem.

PiperOrigin-RevId: 442868210
2022-04-19 11:39:12 -07:00
Tianjian Lu
5a1c5ba114 [linalg] Adds compute_uv to TPU SVD.
PiperOrigin-RevId: 442864883
2022-04-19 11:28:43 -07:00
Sharad Vikram
5ff2e8eb4c Fix name stack bugs 2022-04-19 11:14:41 -07:00