1807 Commits

Author SHA1 Message Date
YouJiacheng
b485b8e5ce implement scipy.cluster.vq.vq
also add no check_finite and overwrite_* docstring for some scipy.linalg functions
2022-04-23 03:14:32 +08:00
Eugene Burmako
0ed29b63f0 [MHLO] Add MHLO lowering for erf and erfc
erf implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=319-336;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

PiperOrigin-RevId: 443665435
2022-04-22 07:54:23 -07:00
Eugene Burmako
5f16873aad [MHLO] Switch tan to use CHLO lowering
Currently, it's desugared to sin(x)/cos(x) with upcast because CHLO_TanOp
legalization doesn't support complex numbers.

tan implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1175-1177;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

PiperOrigin-RevId: 443649394
2022-04-22 06:28:01 -07:00
Eugene Burmako
636345fd67 [MHLO] Add MHLO lowerings of remaining ops blocked by the lack of complex support in CHLO
The affected ops are: acosh, asinh and atanh
(in addition to cosh which was fixed a few days ago).

acosh implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1181-1216;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

asinh implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1218-1270;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

atanh implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1272-1292;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

PiperOrigin-RevId: 443590210
2022-04-22 00:39:02 -07:00
jax authors
5013bd2e3a Merge pull request #10402 from froystig:aot-jit-avoid-trivial
PiperOrigin-RevId: 443533232
2022-04-21 18:13:10 -07:00
Aart Bik
fb370b86ff Adds a ability to pass computation directly as module to backend
PiperOrigin-RevId: 443512012
2022-04-21 16:38:16 -07:00
Roy Frostig
5c118071cb always lower/compile computations on the AOT jit path
... even trivial ones.
2022-04-21 15:30:36 -07:00
Jake VanderPlas
5782210174 CI: fix flake8 ignore declarations 2022-04-21 13:44:12 -07:00
jax authors
48551be9c9 Merge pull request #10371 from sharadmv:mlir-token-lowering
PiperOrigin-RevId: 443458234
2022-04-21 13:08:21 -07:00
Aart Bik
c1261ccd27 Adds a wrapper to sparse tensor dialect, as part of an
an initial prototype of an alternate JAX compilation path
that emits the MLIR MHLO/CHLO dialects instead of classic XLA HLO
together with sparse tensor types.

PiperOrigin-RevId: 443438043
2022-04-21 11:48:44 -07:00
Sharad Vikram
f17c09eb8d add in mlir lowering for tokens 2022-04-21 11:28:58 -07:00
jax authors
bef5e02816 Merge pull request #10150 from jakevdp:normalize-unsigned
PiperOrigin-RevId: 443416771
2022-04-21 10:34:35 -07:00
Peter Hawkins
74346f464b [JAX] Change jnp.take_along_axis to return invalid (e.g. NaN) values for out-of-bounds indices.
Previously, out-of-bounds indices were clipped into range, but that behavior is error prone. We would rather fail in a more visible way when out-of-bounds indices are used. Future changes will migrate other JAX indexing operations to have the same semantics.

PiperOrigin-RevId: 443390170
2022-04-21 08:52:14 -07:00
Jake VanderPlas
92ca76a039 Skip normalization of unsigned indices 2022-04-20 16:04:12 -07:00
Matthew Johnson
e313428f2e remove lax._device_put_raw 2022-04-20 14:49:53 -07:00
Tianjian Lu
455c9f823e [linalg] Adds full_matrices option to TPU SVD.
PiperOrigin-RevId: 443163571
2022-04-20 12:32:00 -07:00
YouJiacheng
af7b94b110 Fix typo of #10381
and add a basic regression test
2022-04-21 02:17:09 +08:00
jax authors
e9588ab43f Merge pull request #10381 from YouJiacheng:remove-numpy.linalg._promote_arg_dtypes
PiperOrigin-RevId: 443133279
2022-04-20 10:37:04 -07:00
Peter Hawkins
4fd824c36f Change jnp.take_along_axis to require that its indices are of integer type.
Previously jnp.take_along_axis silently casted its indices to integers if they were not already integers.

PiperOrigin-RevId: 443124521
2022-04-20 10:05:16 -07:00
jax authors
ea0233b995 Merge pull request #10379 from hawkinsp:onehot
PiperOrigin-RevId: 443122227
2022-04-20 09:57:05 -07:00
jax authors
4dce508b2a Merge pull request #10377 from YouJiacheng:patch-8
PiperOrigin-RevId: 443119505
2022-04-20 09:45:05 -07:00
jax authors
606d1c2402 Merge pull request #10385 from jakevdp:promote-types-comment
PiperOrigin-RevId: 443116925
2022-04-20 09:39:46 -07:00
jax authors
b2132f7884 Merge pull request #10358 from LenaMartens:changelist/442788164
PiperOrigin-RevId: 443116900
2022-04-20 09:34:22 -07:00
YouJiacheng
bb2682db6d remove numpy.linalg._promote_arg_dtypes
in favor of numpy.util._promote_dtypes_inexact
2022-04-21 00:23:56 +08:00
Jake VanderPlas
2588c98586 Add comment explaining implementation in promote_types 2022-04-20 08:44:49 -07:00
jax authors
a03be320d1 Merge pull request #10368 from hawkinsp:takealongaxis2
PiperOrigin-RevId: 443076407
2022-04-20 06:22:29 -07:00
Peter Hawkins
6e6f693e6d Use lax.broadcasted_iota in jax.nn.one_hot.
Minor cleanup that means we emit one fewer MHLO op, no functional changes intended.
2022-04-20 09:09:19 -04:00
YouJiacheng
f6ca60ec29
DOC: lax.linalg.eigh
Fix the inconsistency of variable name between docstring and source code.
Add description of eigenvalues
2022-04-20 16:23:16 +08:00
jax authors
27d32e4a74 Merge pull request #10367 from mattjj:10366
PiperOrigin-RevId: 442909079
2022-04-19 14:17:38 -07:00
Matthew Johnson
6f606a0b57 fix issue #10366 2022-04-19 13:18:00 -07:00
Peter Hawkins
a52f07a21b Add an optional mode= argument to jnp.take_along_axis.
This allows users of jnp.take_along_axis to override the out-of-bounds indexing behavior.
Default to "clip", which for the forward computation is identical to the current behavior. In a future change, we will change this to "fill".
2022-04-19 16:07:00 -04:00
jax authors
7008b32132 Merge pull request #10296 from sharadmv:jax2tf-name-stack
PiperOrigin-RevId: 442872933
2022-04-19 11:57:19 -07:00
Peter Hawkins
e1b606934f Temporarily revert: Change default jnp.take_along_axis gather mode to "fill".
Some tests were broken by the change; reverting this PR for the moment while debugging the problem.

PiperOrigin-RevId: 442868210
2022-04-19 11:39:12 -07:00
Tianjian Lu
5a1c5ba114 [linalg] Adds compute_uv to TPU SVD.
PiperOrigin-RevId: 442864883
2022-04-19 11:28:43 -07:00
Sharad Vikram
5ff2e8eb4c Fix name stack bugs 2022-04-19 11:14:41 -07:00
Peter Hawkins
7c73bfbc46 Change default jnp.take_along_axis gather mode to "fill".
PiperOrigin-RevId: 442817397
2022-04-19 08:24:24 -07:00
jax authors
b8971b9f28 Reapply: fff370d78d107ed81431becf9dfe97eba77863fb by Lukas Geiger <lukas.geiger94@gmail.com>:
Prefer `jnp.tile` over `concatenate`

PiperOrigin-RevId: 442803459
2022-04-19 07:12:27 -07:00
Lena Martens
e4836f5663 Checkify: support checks on data-independent values.
You can now check values which do not depend on checkified args.

Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-04-19 15:06:00 +01:00
jax authors
6b99ee6e48 Merge pull request #10341 from jakevdp:take-doc
PiperOrigin-RevId: 442707713
2022-04-18 21:22:09 -07:00
jax authors
fc2a12c478 Temporarily revert fff370d78d107ed81431becf9dfe97eba77863fb by Lukas Geiger <lukas.geiger94@gmail.com>:
Prefer `jnp.tile` over `concatenate`

PiperOrigin-RevId: 442693096
2022-04-18 19:34:30 -07:00
Peter Hawkins
21e1f8c3d1 [JAX] Delete last references to conv/dot translation rules.
Replace references with MHLO equivalents.

PiperOrigin-RevId: 442675847
2022-04-18 17:42:47 -07:00
Gain Hagenau
59d8b8d6b2 Remove flags set for all v4 TPUs. Topology flags will now be set in libTPU.
Remove deprecated fields `TPU_MESH_CONTROLLER_ADDRESS` and `TPU_MESH_CONTROLLER_PORT`.

PiperOrigin-RevId: 442663216
2022-04-18 16:39:34 -07:00
jax authors
f6705fc269 Merge pull request #10221 from lgeiger:concat-tile
PiperOrigin-RevId: 442587085
2022-04-18 11:19:07 -07:00
jax authors
a3f3af3ac7 Merge pull request #10287 from YouJiacheng:patch-6
PiperOrigin-RevId: 442585443
2022-04-18 11:12:20 -07:00
Jake VanderPlas
437f942b1a jnp.take: add documentation for mode parameter default 2022-04-18 10:12:30 -07:00
Peter Hawkins
a48752a578 [MHLO] Remove most XLA translation rules.
Almost all XLA translation rules have MHLO equivalents at this point, and there are no code paths that use the XLA translation rules in preference to their MLIR equivalents.

PiperOrigin-RevId: 442547482
2022-04-18 08:28:35 -07:00
Peter Hawkins
0150d15cb2 Increase minimum jaxlib version to 0.3.7.
Drop backwards compatibility with older jaxlib versions.
2022-04-18 08:09:50 -04:00
Lukas Geiger
fff370d78d Prefer jnp.tile over concatenate 2022-04-18 10:55:30 +01:00
jax authors
c0fe9256c1 Merge pull request #10313 from google:docs7
PiperOrigin-RevId: 442311111
2022-04-16 22:42:39 -07:00
Peter Hawkins
c4ba450867 [MHLO] Add explicit XLA translation rules for primitives that lack MHLO lowerings that rely on standard_primitive registering a translation rule.
At the moment this change does nothing since standard_primitive already registers these same translation rules. The change is in preparation for removing the behavior of standard_primitive of registering an XLA translation rule.

PiperOrigin-RevId: 442222533
2022-04-16 07:01:19 -07:00