Native lowering requires that dimension variables be computable from the shapes of the arguments that are kept by lowering. We had some tests that were using dimension variables but were not using the actual inputs. Then lowering removes the inputs and it is not possible anymore to recover the values of the dimension variables at invocation time.
Here we primarily changed the tests to ensure they use not just the shape of the input but the actual value. In some cases we have disabled some testing, until
https://github.com/google/jax/issues/14437 is fixed
PiperOrigin-RevId: 509171805
It seems to me that jaxlib development must be mostly happening on CI, because some basics are pretty essential. Here are a few things I've been typing/carrying for a while in my flow:
* Add .bazelrc.user to .gitignore so it doesn't accidentally get checked in.
* Add configs for 'debug_symbols' and 'debug' that make some things minimally workable under a debugger (or to get backtraces, etc).
* Add `--force-reinstall` to the copy/paste command to update a built jaxlib wheel (without this, if you are iterating, it fairly quietly does nothing).
Before:
```
ValueError: Devices of all `Array` inputs and outputs should be the same. Got array device ids [0] on platform CPU and another array's device ids [0, 1, 2, 3] on platform CPU
```
After:
```
ValueError: Received incompatible devices for jitted computation. Got argument inp of ArrayPjitTest.test_jit_with_sharding_constraint_committed_inp_error.<locals>.sharded_inp with bfloat16[8,2] and device ids [0] on platform CPU and with_sharding_constraint or nested pjit or shard_map with device ids [0, 1, 2, 3] on platform CPU at jax/tests/pjit_test.py:2509 (sharded_inp)
```
PiperOrigin-RevId: 508746961
These are the methods that are only valid for actual materialized arrays (i.e. not Tracers)
In order to simplify the experience for users, we want to maintain only a single jax.Array
type, so we define all methods here and raise explicit errors on Tracer instances.
Add more checks to catch early the cases when the called TF function
returns values that are not convertible to JAX values (arrays of
numeric values). All these cases were resulting in errors even before
but sometimes these errors were deep in the stack and harder to
diagnose.
Also:
* fixed some of the tests that were using the shape but not the value
of the input arguments
* fix importing of mlir.py due to recent move of interpreters.mlir to
_src.interpreters.mlir