Historically, tests that only ran on GPUs were placed in `OpsExtraTest`, while general tests were in `OpsTest`. However, this separation may cause us to miss issues that should be addressed on TPUs as well. Going forward, all tests will be unified in `OpsTest`, and any tests that fail on TPUs will be skipped individually using `skipTest`. This will help us better track and address TPU-specific failures.
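A minimal sketch of the per-test skipping pattern this implies, assuming the usual jax._src.test_util helpers; the test name and skip message are hypothetical:

    import jax._src.test_util as jtu
    from absl.testing import absltest

    class OpsTest(jtu.JaxTestCase):

      def test_some_op(self):  # hypothetical test
        if jtu.device_under_test() == "tpu":
          # Skip individually instead of keeping a GPU-only OpsExtraTest.
          self.skipTest("Not yet supported on TPU")
        # ... actual test body ...

    if __name__ == "__main__":
      absltest.main(testLoader=jtu.JaxTestLoader())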
PiperOrigin-RevId: 680747902
We still have this temporary check in apply-vector-layout, but in infer-vector-layout, instead of throwing an error, we should just reset the offset to zero, because some ops that have relaxed this restriction might be passed as inputs to un-relaxed ops and cause a failure.
PiperOrigin-RevId: 680706301
This fixes the failure in the elementwise rule of the apply-vector-layout pass.
If the condition scalar is static, MLIR simplifies the select to the corresponding vector from the true or false value.
If the condition scalar is dynamic, we want to use vselect over scf.if anyway, because the latter creates an inner region.
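As a hedged, JAX-level illustration of the two cases (not the pass itself; the function and values are made up), a scalar condition selecting between vectors can arise like this:

    import jax
    import jax.numpy as jnp

    def select(pred, x, y):
      # A scalar boolean condition choosing between two vectors.
      return jnp.where(pred, x, y)

    x = jnp.arange(8, dtype=jnp.float32)
    y = -x
    # Static condition: the select has a constant predicate, which MLIR
    # simplifies to the chosen operand.
    jax.jit(select, static_argnums=0)(True, x, y)
    # Dynamic condition: lowered as a vector select (vselect) rather than
    # scf.if, which would create an inner region.
    jax.jit(select)(jnp.bool_(True), x, y)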
PiperOrigin-RevId: 680674560
This change makes it so that the refs users receive inside their kernels have shapes
matching their block specs. However, the refs are not actually plain refs, but transformed
references that begin with the fully transformed abstract ref and then stack the inverse
of the transformation stack on top of it. This means that all primitives that take in refs
can also see the sequence of transforms the user applied in the block spec, which lets us
verify e.g. that the inputs to WGMMA are correctly tiled, even though their user-visible
shape remains 2D. We should be able to use the same trick in the future to propagate tiling
and better infer the layouts for loads and stores.
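For context, a generic Pallas sketch (interpreter mode, not Mosaic-GPU-specific; shapes and names are made up, and it assumes the current pl.BlockSpec(block_shape, index_map) order) of a ref whose shape matches its block spec:

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def add_one_kernel(x_ref, o_ref):
      # x_ref.shape is the block shape from the BlockSpec, i.e. (128, 128),
      # regardless of any transforms applied to the underlying memory.
      o_ref[...] = x_ref[...] + 1.0

    x = jnp.zeros((512, 512), jnp.float32)
    out = pl.pallas_call(
        add_one_kernel,
        grid=(4, 4),
        in_specs=[pl.BlockSpec((128, 128), lambda i, j: (i, j))],
        out_specs=pl.BlockSpec((128, 128), lambda i, j: (i, j)),
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,
    )(x)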
PiperOrigin-RevId: 680520185
* Delete custom_object_test, since it is disabled and has been ever since jax.Array was introduced in JAX 0.4.0.
* custom_linear_solve_test was over-sharded, leading to some shards not having any test cases. Even unsharded it completes in under 65s on every platform we have.
* config_test and pallas splash attention mask test only tested helpers and didn't need a TPU.
PiperOrigin-RevId: 679711664
Most users of disable_backends were actually using it to enable only a single backend. So things are simpler if we negate the sense of the option to say that. Change disable_configs to enable_configs, with a default `None` value meaning "everything is enabled".
We change the relationship between enable_backends, disable_configs, and enable_configs to be the following (see the sketch after this list):
* `enable_backends` selects an initial set of test configurations to enable, based only on the backend.
* `disable_configs` then prunes that set of test configurations, removing elements from the set.
* `enable_configs` then adds additional configurations to the set.
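Purely as an illustration of those three steps (hypothetical Python, not the actual Starlark in the test rules):

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class TestConfig:  # illustrative stand-in for a test configuration
      name: str
      backend: str

    def select_configs(all_configs, enable_backends, disable_configs, enable_configs):
      # 1. enable_backends picks an initial set, based only on the backend.
      selected = {c for c in all_configs if c.backend in enable_backends}
      # 2. disable_configs prunes configurations from that set.
      selected = {c for c in selected if c.name not in disable_configs}
      # 3. enable_configs adds further configurations to the set.
      selected |= {c for c in all_configs if c.name in enable_configs}
      return selected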
Fix code in jax/experimental/mosaic/gpu/examples not to depend on a Google-internal GPU support target.
PiperOrigin-RevId: 679563155
The previous implementation made some surprising assumptions about the contents
of the block specs and wasn't correct in general. The new implementation handles
all the cases and seems to be sufficient to finally run the matmul example with
multiple k steps while producing correct results (it's also shorter!).
PiperOrigin-RevId: 679175212
We never checked whether the output windows were done being written before we reused them.
Also, rename num_stages to max_concurrent_steps, since we only ever have 2 stages
but might be running multiple iterations at a time.
Also fix the test for this, which had been passing for reasons I don't understand
(it didn't even write to all entries in the output??).
PiperOrigin-RevId: 679148961
I annotated a number of issues in the test. To make the test run I also needed to add support
for the accumulator reference allocation and discharge in the main lowering part. Ideally,
we'd defer it all to run_scoped, but run_scoped can't allocate barriers...
PiperOrigin-RevId: 679143948
* Changes the accumulator into a reference
* Creates a discharged flavor of the wgmma op
* run_scoped lowering discharges the input jaxpr
* Dereferencing the accumulator ref is done by a new primitive that behaves as expected when discharged
* The deref primitive implies flushing the wgmma pipeline (see the sketch below).
* run_scoped does not allow references to be leaked.
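A rough, non-authoritative sketch of what this looks like from the kernel side, assuming the plgpu.ACC, plgpu.wgmma, and pl.run_scoped spellings of the Pallas Mosaic GPU API; shapes are made up and details of the real API may differ:

    import jax.numpy as jnp
    from jax.experimental import pallas as pl
    from jax.experimental.pallas import mosaic_gpu as plgpu

    def matmul_kernel(a_ref, b_ref, o_ref):
      def scoped(acc_ref):
        # wgmma takes the accumulator as a reference rather than a value.
        plgpu.wgmma(acc_ref, a_ref, b_ref)
        # Reading the accumulator ref is the new deref primitive; it implies
        # flushing the WGMMA pipeline before the values are used.
        return acc_ref[...]
      # run_scoped allocates and discharges the accumulator; the reference
      # must not leak out of the scope.
      o_ref[...] = pl.run_scoped(scoped, plgpu.ACC((64, 64), jnp.float32))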
PiperOrigin-RevId: 679056765