Lambdas are represented by their ids in the metadata of lowered HLO (see example below) and they change every time. This makes the compilation cache less effective as it causes the computation's fingerprint to change every time.
```
get-tuple-element.41724 = bf16[8]{0} get-tuple-element(reduce.41723), index=0, metadata={op_name="pjit(_wrapped_fn)/jit(main)/.../reduce[computation=<function _compute_argminmax.<locals>.reducer_fn at 0x7fa6ecfb2200> dimensions=(1,)]" source_file="..." source_line=...}
```
PiperOrigin-RevId: 601910715
We don't need to support `isinstance(..., PRNGKeyArray)` on tracers any longer, since `PRNGKeyArray` is no longer a public symbol.
PiperOrigin-RevId: 601815616
The block argument of tt.reduce is always parameterized by scalars.
Note that this bug had no effect on the emitted Triton IR, because the
lowering code does not currently rely on avals.
PiperOrigin-RevId: 601801294
Why? We've found in practice that downstream projects use fold_in multiple
times with the same key. This is safe so long as the folded-in value is
different every time; in this sense fold_in() is similar to seed(), and
for now we must trust the user to not repeat seeds.
LoadModuleFromData has (data, format, config, ...) signature while FromFile has (path, config, format, ...). Change the latter so `format` becomes the second argument in both cases.
Since I'm touching this file:
* Use `std::string_view` and `absl::Status`
* Change `ovr_config` parameter to `const &`
PiperOrigin-RevId: 601304308
We are migrating some attrs on some StableHLO ops to use DenseI64ArrayAttr instead of DenseIntElementsAttr. Using DenseI64ArrayAttr enforces that the attr values are 1-dimensional and provides nicer APIs. (see https://github.com/openxla/stablehlo/issues/1578 for additional context)
Unfortunately, we have to duplicate the `dense_int_array` function because we migrated the ops in batches. We can't use the existing `dense_int_array` function because it would produce arrays for ops that hadn't yet been migrated. This PR makes the final batch of changes, so no additional methods should be added going forward.
We also have to introduce a new `dense_bool_array` function, with a similar version check.
When the minimum supported jaxlib version uses a recent enough version of StableHLO (v6 or above), it will be possible to remove the version checks and remove the duplicated `dense_int_array_v6` function.
PiperOrigin-RevId: 601271749
Note that all primitives are now lowered to libdevice calls. Previously,
some of them were lowered to the MLIR arith dialect, and some to libdevice
calls, without any apparent reason for doing so.
PiperOrigin-RevId: 601259707
As I explore more powerful ways to reason about inequalities,
I came up with more tests of inequalities that I wish we can handle.
This PR adds the tests I have so far, even if they do not produce
the correct result yet. I write the expected values for tests as
_expect(best=v1, current=v2)
to document that the current logic produces `v2` but the best value
we can hope for is `v1`.
This PR also adds more support for profiling tests.