346 Commits

Author SHA1 Message Date
Peter Hawkins
0b470361da Change the default jnp.take mode to "fill".
Previously, `jnp.take` defaulted to clamping out-of-bounds indices into range. Now, `jnp.take` returns invalid values (e.g., NaN) for out-of-bounds indices. This change attempts to prevent latent bugs caused by inadvertent out-of-bounds indices.

The previous behavior can be approximated using the "clip" or "wrap" fill modes.

PiperOrigin-RevId: 445130143
2022-04-28 06:01:56 -07:00
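A hedged illustration of the new default described above (a sketch of the post-change `jnp.take` semantics; exact fill values depend on dtype):

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
idx = jnp.array([0, 2, 5])            # index 5 is out of bounds

# New default ("fill"): out-of-bounds indices yield a fill value (NaN for floats).
print(jnp.take(x, idx))               # roughly: [ 1.  3. nan]

# The old clamping behavior is still available explicitly:
print(jnp.take(x, idx, mode="clip"))  # [1. 3. 3.]
```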
Matthew Johnson
fde6305012 refine const folding 2022-04-24 21:04:06 -07:00
Matthew Johnson
bf64f1843f [remove-units] prevent cond partial eval from introducing units 2022-04-24 14:28:33 -07:00
Eugene Burmako
0ed29b63f0 [MHLO] Add MHLO lowering for erf and erfc
erf implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=319-336;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

PiperOrigin-RevId: 443665435
2022-04-22 07:54:23 -07:00
Eugene Burmako
5f16873aad [MHLO] Switch tan to use CHLO lowering
Currently, it's desugared to sin(x)/cos(x) with upcast because CHLO_TanOp
legalization doesn't support complex numbers.

tan implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1175-1177;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

PiperOrigin-RevId: 443649394
2022-04-22 06:28:01 -07:00
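A minimal numerical sketch of the identity the desugaring relies on (illustrative only; the actual change lives in the CHLO/MHLO legalization, not in user code):

```python
import jax.numpy as jnp

z = jnp.complex64(0.3 + 0.4j)
# tan(z) == sin(z) / cos(z); the lowering computes the right-hand side at
# higher precision (the "upcast") because CHLO_TanOp lacks complex support.
print(jnp.tan(z))
print(jnp.sin(z) / jnp.cos(z))   # should agree with jnp.tan(z) to float32 tolerance
```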
Eugene Burmako
636345fd67 [MHLO] Add MHLO lowerings of remaining ops blocked by the lack of complex support in CHLO
The affected ops are acosh, asinh, and atanh
(in addition to cosh, which was fixed a few days ago).

acosh implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1181-1216;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

asinh implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1218-1270;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

atanh implementation in the old bridge: https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1272-1292;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

PiperOrigin-RevId: 443590210
2022-04-22 00:39:02 -07:00
jax authors
48551be9c9 Merge pull request #10371 from sharadmv:mlir-token-lowering
PiperOrigin-RevId: 443458234
2022-04-21 13:08:21 -07:00
Sharad Vikram
f17c09eb8d add in mlir lowering for tokens 2022-04-21 11:28:58 -07:00
Matthew Johnson
e313428f2e remove lax._device_put_raw 2022-04-20 14:49:53 -07:00
Sharad Vikram
5ff2e8eb4c Fix name stack bugs 2022-04-19 11:14:41 -07:00
Peter Hawkins
21e1f8c3d1 [JAX] Delete last references to conv/dot translation rules.
Replace references with MHLO equivalents.

PiperOrigin-RevId: 442675847
2022-04-18 17:42:47 -07:00
Peter Hawkins
a48752a578 [MHLO] Remove most XLA translation rules.
Almost all XLA translation rules have MHLO equivalents at this point, and there are no code paths that use the XLA translation rules in preference to their MLIR equivalents.

PiperOrigin-RevId: 442547482
2022-04-18 08:28:35 -07:00
Peter Hawkins
0150d15cb2 Increase minimum jaxlib version to 0.3.7.
Drop backwards compatibility with older jaxlib versions.
2022-04-18 08:09:50 -04:00
Peter Hawkins
c4ba450867 [MHLO] Add explicit XLA translation rules for primitives that lack MHLO lowerings and currently rely on standard_primitive registering a translation rule.
At the moment this change does nothing, since standard_primitive already registers these same translation rules. The change is in preparation for removing standard_primitive's behavior of registering an XLA translation rule.

PiperOrigin-RevId: 442222533
2022-04-16 07:01:19 -07:00
Eugene Burmako
b267fd4336 [MHLO] Add MHLO lowering for cosh
Here's the corresponding old bridge lowering:
https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1294-1309;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

getConstantLike didn't support complex numbers, and proper support required
non-trivial work, so in the meantime I've hacked up something that works
for static shapes to unblock the JAX use case (which currently only uses static shapes).

PiperOrigin-RevId: 442083258
2022-04-15 13:26:20 -07:00
Peter Hawkins
c2fe97ae01 Improve precision of chlo.sinh.
Update chlo.sinh lowering to match xla::Sinh(), see https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1311?q=xla%20sinh

[JAX] Use chlo.sinh instead of the XLA client library HLO lowering.

PiperOrigin-RevId: 441851170
2022-04-14 14:10:26 -07:00
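A small sketch of why a naive sinh formula loses precision near zero (assumed motivation for the precision fix; the values shown are illustrative):

```python
import jax.numpy as jnp

x = jnp.float32(1e-6)
naive = (jnp.exp(x) - jnp.exp(-x)) / 2   # cancellation: only a few accurate digits in float32
print(naive)         # visibly off from the true value ~1e-6
print(jnp.sinh(x))   # the dedicated lowering avoids the cancellation
```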
jax authors
6914e35af1 Merge pull request #10270 from mattjj:djax-iree
PiperOrigin-RevId: 441812895
2022-04-14 11:33:10 -07:00
Matthew Johnson
d21b958f30 add some simple iree tests
This passes, though two of the interesting tests fail with what might be IREE
bugs (and so are currently skipped):

```shell
JAX_PLATFORMS='iree' pytest -n auto tests/core_test.py tests/api_test.py -k Dynamic
```
2022-04-14 10:55:00 -07:00
Peter Hawkins
665df8dfaa [MHLO] Add an MHLO lowering for rng_bit_generator.
PiperOrigin-RevId: 441628987
2022-04-13 18:09:36 -07:00
Matthew Johnson
4354f355a8 prototyping dynamic shapes
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-04-11 22:10:47 -07:00
Sharad Vikram
0fa1eddd25 Adds simple effect types to jaxprs 2022-04-11 11:50:41 -07:00
Matthew Johnson
902fc0c3d2 Remove invertible_ad since it's not in use.
PiperOrigin-RevId: 440890949
2022-04-11 07:56:58 -07:00
jax authors
b713d3ce4b Minor change to lax to support jax2tf shape polymorphic concatenation.
PiperOrigin-RevId: 440113799
2022-04-07 08:34:27 -07:00
Peter Hawkins
bc658e7456 [MHLO] Add direct MHLO lowerings for most linear algebra kernels.
PiperOrigin-RevId: 439927594
2022-04-06 13:59:09 -07:00
Peter Hawkins
b9bb61322c [MHLO] Prefer backend-specific HLO lowerings instead of non-backend-specific MHLO lowerings.
This allows us (in subsequent changes) to switch the generic case of translating a primitive to MHLO, even if we can't yet use an MHLO lowering for a particular backend-specific case.

Add a handful of direct MLIR lowerings for primitives that lacked them.

PiperOrigin-RevId: 439912093
2022-04-06 12:53:56 -07:00
Peter Hawkins
ade9f1a294 Share compare_mhlo function between lax.py and mlir.py.
Use the .shape property on RankedTensorType.
2022-03-30 17:02:19 -04:00
Benjamin Kramer
a04b777c54 [mhlo] Clean up ops that can use InferTensorTypeWithReify
This means we can get rid of custom builders and return type inference. This
all goes through inferReturnTypeComponents now, so fix obvious bugs in those
implementations.

There should be no behavioral change. However, Python bindings no longer
generate a result type builder for mhlo.CompareOp, which is unfortunate.

PiperOrigin-RevId: 438341237
2022-03-30 10:44:16 -07:00
Ayaka Mikazuki
2799bb3cde [doc] Fix typo 2022-03-29 21:29:51 +08:00
Sandeep Dasgupta
6cd9804163 Replace (deprecated) StrEnumAttr with EnumAttr.
ref: https://reviews.llvm.org/D120834
PiperOrigin-RevId: 435550738
2022-03-17 23:11:28 -07:00
Thomas Köppe
c3a4a6e63d Revert previous change
PiperOrigin-RevId: 435397906
2022-03-17 11:19:49 -07:00
Lena Martens
1d5833d2f1 Reshape top_k operand to 2D by collapsing the batch dimensions when lowering.
PiperOrigin-RevId: 435374934
2022-03-17 10:00:24 -07:00
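A rough Python-level sketch of the reshape trick described (the real change lives in the lowering; `top_k_batched` and the shapes here are illustrative assumptions):

```python
import jax
import jax.numpy as jnp

def top_k_batched(x, k):
    # Collapse all leading (batch) dimensions into one, run top_k on the
    # resulting 2D operand, then restore the batch shape on the outputs.
    batch_shape = x.shape[:-1]
    x2d = x.reshape((-1, x.shape[-1]))
    vals, idxs = jax.lax.top_k(x2d, k)
    return vals.reshape(batch_shape + (k,)), idxs.reshape(batch_shape + (k,))

x = jnp.arange(24.0).reshape(2, 3, 4)
vals, idxs = top_k_batched(x, 2)
print(vals.shape, idxs.shape)  # (2, 3, 2) (2, 3, 2)
```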
Robert Suderman
97ddf986bc Make concatenate allow concatenation on dynamic dimensions
Concatenating two dynamic shapes together along those dynamic dimensions
should be allowed.

PiperOrigin-RevId: 434577959
2022-03-14 15:06:38 -07:00
Sharad Vikram
2988901e6c Refactor Jaxpr pretty-printing to use a JaxprPpSettings named tuple
and thread it into `pp_eqn_rules` so the settings are used recursively
2022-03-09 17:40:05 -08:00
jax authors
2a3f936ffa Merge pull request #9576 from nicholasjng:broadcast-validation
PiperOrigin-RevId: 432531230
2022-03-04 14:21:17 -08:00
Nicholas Junge
56546d3e73 Validate lax.broadcast_shapes inputs before control flow execution
This commit addresses previously unvalidated inputs to `jax.lax.broadcast_shapes` by adding a small validation check before control flow execution. A legal input to `lax.broadcast_shapes` is hereafter defined as an input that
1) is a sequence (i.e., implements for..in iteration) of integers and
2) contains only non-negative integers.

In addition, two tests were added to `tests.lax_vmap_test` to check that proper errors are raised when attempting to use illegal inputs with `lax.broadcast_shapes`.
2022-03-04 19:27:52 +01:00
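A hedged sketch of the validated behavior (the exact error type and message are assumptions, not taken from the commit):

```python
from jax import lax

# Legal inputs: sequences of non-negative integers.
print(lax.broadcast_shapes((2, 3), (1, 3), (7, 1, 3)))   # (7, 2, 3)

# Illegal inputs are now rejected up front rather than failing mid-computation.
try:
    lax.broadcast_shapes((2, -3))
except (TypeError, ValueError) as e:   # exact error type may differ
    print(e)
```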
Peter Hawkins
c978df5550 Increase minimum jaxlib version to 0.3.0. 2022-03-04 10:33:03 -05:00
Anselm Levskaya
f6a5f0dca2 Use real Precision type for lax.PrecisionType
PiperOrigin-RevId: 432413742
2022-03-04 04:21:25 -08:00
Sharad Vikram
1b79caa6bd Add separate mechanism for threading name stacks to the lowering 2022-02-23 09:59:09 -08:00
Roy Frostig
35fab1a95a err on repeated axes to expand_dims, as numpy does 2022-02-17 11:27:20 -08:00
Roy Frostig
0f7904f883 implement jnp.expand_dims and jnp.stack for PRNGKeyArrays
Also:
* fix `jnp.concatenate` and `jnp.append` for PRNGKeyArrays
* add `ndim` property to PRNGKeyArrays
* minor fix to `lax.expand_dims` with duplicate dimensions
2022-02-16 20:47:27 -08:00
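A minimal usage sketch, assuming the typed-key spelling `jax.random.key` (at the time of this commit, PRNG key arrays sat behind a config flag):

```python
import jax
import jax.numpy as jnp

keys = jax.random.split(jax.random.key(0), 4)   # key array of shape (4,)

print(keys.ndim)                            # 1  (the new ndim property)
print(jnp.expand_dims(keys, 0).shape)       # (1, 4)
print(jnp.stack([keys, keys]).shape)        # (2, 4)
print(jnp.concatenate([keys, keys]).shape)  # (8,)
```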
Lena Martens
b15c7f609a Checkify: fix check_error of nd-error.
PiperOrigin-RevId: 428857813
2022-02-15 13:12:53 -08:00
Peter Hawkins
29c8a04527 Fix incorrect binary search comparison in lax.select_n lowering.
Fixes issue in https://github.com/google/jax/discussions/9556#discussioncomment-2175113
2022-02-14 14:29:38 -05:00
jax authors
0566ea4ccd Merge pull request #9456 from mattjj:jaxpr-pprint-color-flag-and-default
PiperOrigin-RevId: 428247626
2022-02-12 14:49:02 -08:00
Matthew Johnson
004bb684ea add a flag for colorful jaxpr syntax highlighting
and enable it by default 2022-02-12 14:15:28 -08:00
2022-02-12 14:15:28 -08:00
Peter Hawkins
8ca6622c0b Change lax.select_p to an n-ary select primitive, `lax.select_n_p`. Change lax.select() to be a thin shim around the new n-ary version.
Note that one key difference between `lax.select_p` and `lax.select_n_p` is that the order of the cases is reversed for boolean predicates. This merited a new name to minimize confusion.

Use lax.select_n() in conditional batching. This means that we only produce one `select_n()` primitive for each conditional output, rather than a tree. While this has no effect on the number of HLO operators we generate, it can reduce the number of jaxpr equations significantly.

PiperOrigin-RevId: 427517899
2022-02-09 11:03:09 -08:00
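A small sketch of the reversed case order for boolean predicates (illustrative values; not from the commit):

```python
import jax.numpy as jnp
from jax import lax

pred = jnp.array([True, False, True])
a = jnp.array([1, 2, 3])
b = jnp.array([10, 20, 30])

# lax.select picks on_true where pred is True...
print(lax.select(pred, a, b))     # [ 1 20  3]
# ...while lax.select_n indexes its cases by the predicate, so case 0 is taken
# where pred is False and case 1 where it is True: the argument order flips.
print(lax.select_n(pred, b, a))   # [ 1 20  3]
```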
Peter Hawkins
f94c42b271 Fix rendering of cube root in docs. 2022-02-04 11:30:10 -05:00
Sandeep Dasgupta
08fd83b68c [mhlo] jax infeed/outfeed lowering changes in response to detupling of the ops.
PiperOrigin-RevId: 425972002
2022-02-02 14:02:32 -08:00
Sandeep Dasgupta
c25d666539 [mhlo] Phase I: Remove tuples from mhlo Infeed/Outfeed/Send/Recv Ops.
The CL does the following:
1. Modify the Send/Recv ops to disallow tuples as returns and arguments.
2. Modify the Infeed/Outfeed ops to allow both tuples and variadic tensors as returns and arguments.
   - Modify the exporter to XLA HLO to handle Infeed/Outfeed with both tuple and non-tuple variants.
   - Adjust the layout of the Infeed op during import and export.
   - This CL does not modify the TF lowering (new/old bridge) or the JAX lowering for Infeed/Outfeed; they will still produce the tuple variant of the ops, which is why the CL has logic to export both variants.

Here is an example of the layout adjustments during import/export.
## Import
```
XLA HLO:   ROOT %infeed = ( ( s32[3, 4]{1, 0}, s32[5, 6]{1, 0}, s32[7, 8]{1, 0}, s32[9,10]{1, 0}), token[]) infeed(token[] %token0)
MHLO: "mhlo.infeed"([[TOKEN]]) layout = [[1, 0], [1, 0], [1, 0], [1, 0]] // A flattened list of data-layouts
```

## Export
```
MHLO: %0:3 = "mhlo.infeed"(%arg0) {layout=[[0, 1], [0]]} : (!mhlo.token) -> (tensor<3x3xi32>, tensor<i1>, !mhlo.token)

XLA HLO:

%INFEED = ((s32[3,3]{0,1}, pred[]), token[]) infeed(token[] %token)
%GTE1 = (s32[3,3]{0,1}, pred[]) get-tuple-element(((s32[3,3]{0,1}, pred[]), token[]) %INFEED), index=0 // accessing the data-tuple
GTE2 = s32[3,3]{0,1} get-tuple-element((s32[3,3]{0,1}, pred[]) %GTE1), index=0  // accessing the data
%GTE3 = pred[] get-tuple-element((s32[3,3]{0,1}, pred[]) %GTE1), index=1  // accessing the data
%GTE4 = token[] get-tuple-element(((s32[3,3]{0,1}, pred[]), token[]) %INFEED), index=1 // accessing the token
```
PiperOrigin-RevId: 425963519
2022-02-02 13:29:08 -08:00
jax authors
e29bfce85c Merge pull request #9415 from hawkinsp:intpow
PiperOrigin-RevId: 425889721
2022-02-02 08:22:02 -08:00
Peter Hawkins
f7ccda5a69 Work around call expansion in MHLO->HLO lowering.
Don't cache and outline lowerings of `x**2` and `x**3`.
2022-02-02 11:02:38 -05:00