Speed up JVPs for 3x3 and 2x2 determinants
The current `det` implementation wraps everything in a single `custom_jvp`, so even though there are fast closed-form functions for the 2x2 and 3x3 cases, their derivatives still go through the slow general JVP. This PR localises the `custom_jvp` to the generic case.
The general-case JVP is ~10x slower on GPU (A100) and ~250x slower on TPU (v2), as the benchmark below shows.
```python
import jax
from jax import numpy as jnp
from jax import random

def det_3x3(a: jax.Array) -> jax.Array:
    return (a[..., 0, 0] * a[..., 1, 1] * a[..., 2, 2] +
            a[..., 0, 1] * a[..., 1, 2] * a[..., 2, 0] +
            a[..., 0, 2] * a[..., 1, 0] * a[..., 2, 1] -
            a[..., 0, 2] * a[..., 1, 1] * a[..., 2, 0] -
            a[..., 0, 0] * a[..., 1, 2] * a[..., 2, 1] -
            a[..., 0, 1] * a[..., 1, 0] * a[..., 2, 2])

key = random.key(42)
x = random.normal(key, (int(1e5), 3, 3))

# Gradient via the general det (custom_jvp) vs. the closed-form 3x3 expression.
general_grad = jax.grad(lambda x: jnp.linalg.det(x).sum())
direct_3by3_grad = jax.vmap(jax.grad(det_3x3))
general_grad, direct_3by3_grad = (jax.jit(f) for f in (general_grad, direct_3by3_grad))

# Warm up so compilation time is excluded from the timings.
_ = jax.block_until_ready(general_grad(x))
_ = jax.block_until_ready(direct_3by3_grad(x))

%timeit _ = jax.block_until_ready(general_grad(x))
%timeit _ = jax.block_until_ready(direct_3by3_grad(x))
```
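For reference, a minimal sketch of the approach, with illustrative names (`_det_general` and `det` are placeholders, not the actual `jax.numpy` internals, and `det_3x3` refers to the benchmark snippet above): the custom JVP is attached only to the generic path, while the 2x2 and 3x3 closed-form expressions differentiate via ordinary autodiff.

```python
import jax
from jax import numpy as jnp

@jax.custom_jvp
def _det_general(a: jax.Array) -> jax.Array:
    # Stand-in for the real LU-based general determinant.
    sign, logabsdet = jnp.linalg.slogdet(a)
    return sign * jnp.exp(logabsdet)

@_det_general.defjvp
def _det_general_jvp(primals, tangents):
    (a,), (da,) = primals, tangents
    y = _det_general(a)
    # Jacobi's formula: d det(A) = det(A) * tr(A^{-1} dA).
    dy = y * jnp.trace(jnp.linalg.solve(a, da), axis1=-2, axis2=-1)
    return y, dy

def det(a: jax.Array) -> jax.Array:
    if a.shape[-1] == 2:
        return a[..., 0, 0] * a[..., 1, 1] - a[..., 0, 1] * a[..., 1, 0]
    if a.shape[-1] == 3:
        return det_3x3(a)  # closed-form path from the benchmark above
    return _det_general(a)
```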
directly instead of defining aliasing for arrays with potentially incompatible
layouts. We only fall back to XLA donation for exactly those arrays whose
input and output layouts are explicitly set to conflicting values.
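For illustration, a minimal donation example (the `update` function is hypothetical; what matters is the aliasing behaviour): when donation is requested and the layouts do not explicitly conflict, the input buffer can be aliased to the output directly, and XLA donation remains the fallback only for explicitly conflicting layouts.

```python
import jax
from jax import numpy as jnp

# Hypothetical function; donation is requested for the first argument.
def update(x):
    return x * 2.0

update = jax.jit(update, donate_argnums=(0,))

x = jnp.arange(8.0)
y = update(x)  # the buffer backing x may be aliased to y; only explicitly
               # conflicting input/output layouts fall back to XLA donation
```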
PiperOrigin-RevId: 667770224
Add `--@local_config_cuda//cuda:override_include_cuda_libs` to override settings for TF wheel.
Forbid building the TF wheel with `--@local_config_cuda//cuda:include_cuda_libs=true`.
PiperOrigin-RevId: 666848518
Tests fixed include:
- `test_globally_sharded_key_array_8x4_multi_device`
- Issue was in `replicate_trailing_dims` where an `xc.OpSharding` was always created. Fixed by creating an equivalent SDY sharding.
- `test_aot_out_info`
  - Issue was that there was no mesh since there weren't any NamedShardings. Fixed by not asserting that a mesh tuple exists in `lower_jaxpr_to_module` when adding the SDY MeshOp (there won't be any propagation).
- `test_concurrent_pjit`
  - In Shardy, if a tensor dimension had size 0, we'd emit a verification error when that dimension was sharded on an axis. But if the axis is of size 1, JAX says this is okay, so Shardy now assumes the same (see the sketch after this list).
- `test_globally_sharded_key_array_result_8x4_single_device`
  - This test adds a WSC when no `mesh_shape_tuple` exists (`"sdy.sharding_constraint"(%8) <{sharding = #sdy.sharding<@mesh, [{?}, {?}, {}]>}>`), so we create a mesh named `mesh` with a single device id in case it doesn't already exist.
- `testLowerCostAnalysis`
- This calls into `mlir_module_to_xla_computation` which calls its own MLIR parsing function in `//third_party/tensorflow/compiler/xla/python/mlir.cc`. Needed to register the SDY dialect in it.
- `testShardingConstraintWithArray`
  - This calls `.compiler_ir(dialect="hlo")`, which calls `PyMlirModuleToXlaComputation`, which converts the MLIR to HLO while the SDY dialect is still inside. Fixed by exporting it before converting to HLO.
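For context, a small sketch of the size-0 case referenced above (illustrative only, not the test itself): JAX accepts sharding a zero-sized dimension over a mesh axis of size 1, so Shardy's verifier was relaxed to match.

```python
import jax
import numpy as np
from jax import numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:1]), ('x',))  # mesh axis 'x' of size 1
x = jnp.zeros((0, 4))                             # leading dimension of size 0
# Sharding the size-0 dimension over the size-1 axis is accepted by JAX.
y = jax.device_put(x, NamedSharding(mesh, P('x', None)))
print(y.sharding)
```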
PiperOrigin-RevId: 666777167