Previously, `jnp.take` defaulted to clamping out-of-bounds indices into range. Now, `jnp.take` returns invalid values (e.g., NaN) for out-of-bounds indices. This change attempts to prevent latent bugs caused by inadvertent out-of-bounds indices.
The previous behavior can be approximated using the "clip" or "wrap" fill modes.
PiperOrigin-RevId: 445130143
Almost all XLA translation rules have MHLO equivalents at this point, and there are no code paths that use the XLA translation rules in preference to their MLIR equivalents.
PiperOrigin-RevId: 442547482
At the moment this change does nothing since standard_primitive already registers these same translation rules. The change is in preparation for removing the behavior of standard_primitive of registering an XLA translation rule.
PiperOrigin-RevId: 442222533
This passes, though two of the interesting tests fail with what might be IREE
bugs (and so are currently skipped):
```shell
JAX_PLATFORMS='iree' pytest -n auto tests/core_test.py tests/api_test.py -k Dynamic
```
This allows (in subsequent changes) to switch the generic case for translating a primitive to MHLO, even if we can't yet use an MHLO lowering for a backend-specific case yet.
Add a handful of direct MLIR lowerings for primitives that lacked them.
PiperOrigin-RevId: 439912093
This means we can get rid of custom builders and return type inference. This
all goes through inferReturnTypeComponents now, so fix obvious bugs in those
implementations.
There should be no behaviorial change. However, python bindings no longer
generate a result type builder for mhlo.CompareOp, which is unfortunate.
PiperOrigin-RevId: 438341237
This commit addresses previously unvalidated inputs to `jax.lax.broadcast_shapes` by adding a small validation check before control flow execution. A legal input to `lax.broadcast_shapes` hereafter is defined as an input that
1) is a sequence (i.e., implements for..in iteration) of integers and
2) said integers are all non-negative.
In addition, two tests were added to `tests.lax_vmap_test` to check that proper errors are raised when attempting to use illegal inputs with `lax.broadcast_shapes`.
Also:
* fix `jnp.concatenate` and `jnp.append` for PRNGKeyArrays
* add `ndim` property to PRNGKeyArrays
* minor fix to `lax.expand_dims` with duplicate dimensions
Note that one key difference between `lax.select_p` and `lax.select_n_p` is that the order of the cases is reversed for boolean predicates. This merited a new name to minimize confusion.
Use lax.select_n() in conditional batching. This means that we only produce one `select_n()` primitive for each conditional output, rather than a tree. While this has no effect on the number of HLO operators we generate, it can reduces the number of jaxpr equations significantly.
PiperOrigin-RevId: 427517899
The CL does the following:
1. Modify the Send/Recv op to disallow tuple as returns and arguments.
2. Modify the infeed/Outfeed op to allow both the tuple and variadic tensors as return and arguments.
- Modify the exporter to XLA hlo to handle infeed/outfeed with both tuple and non-tuple variants.
- Adjust the layout of infeed op during import and export.
- This CL does not modify the tf lowering (new/old bridge) or jax lowering for infeed/outfeed op in this CL. They will still produce the tuple variant of the ops. That is why the CL has logic to export both the variants.
Here is an example of the layout adjustments during import/export.
## Import
```
XLA HLO: ROOT %infeed = ( ( s32[3, 4]{1, 0}, s32[5, 6]{1, 0}, s32[7, 8]{1, 0}, s32[9,10]{1, 0}), token[]) infeed(token[] %token0)
MHLO: "mhlo.infeed"([[TOKEN]]) layout = [[1, 0], [1, 0], [1, 0], [1, 0]] // A flattened list of data-layouts
```
## Export
```
MHLO: %0:3 = "mhlo.infeed"(%arg0) {layout=[[0, 1], [0]]} : (!mhlo.token) -> (tensor<3x3xi32>, tensor<i1>, !mhlo.token)
XLA HLO:
%INFEED = ((s32[3,3]{0,1}, pred[]), token[]) infeed(token[] %token)
%GTE1 = (s32[3,3]{0,1}, pred[]) get-tuple-element(((s32[3,3]{0,1}, pred[]), token[]) %INFEED), index=0 // accessing the data-tuple
GTE2 = s32[3,3]{0,1} get-tuple-element((s32[3,3]{0,1}, pred[]) %GTE1), index=0 // accessing the data
%GTE3 = pred[] get-tuple-element((s32[3,3]{0,1}, pred[]) %GTE1), index=1 // accessing the data
%GTE4 = token[] get-tuple-element(((s32[3,3]{0,1}, pred[]), token[]) %INFEED), index=1 // accessing the token
```
PiperOrigin-RevId: 425963519