7 Commits

Author SHA1 Message Date
Peter Hawkins
dda428e74a Disable tests that trigger warning if x64 mode isn't enabled. 2024-05-10 19:58:22 +00:00
Sergei Lebedev
4b62425b42 Renamed is_device_gpu_at_least to is_cuda_compute_capability_at_least
This makes it clear that the predicate is only supposed to be used for NVIDIA
GPUs at the moment.
2024-05-08 21:41:50 +01:00
Sergei Lebedev
575ba942e0 Removed get_compute_capability from jax.experimental.pallas.gpu
Compute capability is available as a `str` attribute on a GPU device since
jaxlib 0.4.26.
2024-05-08 21:10:43 +01:00
Sergei Lebedev
0feeaa5999 Removed stale version guards and try/except blocks from Pallas GPU
They are unnecessary now that the minimum jaxlib version is 0.4.27.
2024-05-08 17:05:45 +01:00
Sergey Kozub
8738e7e5dc Use correct kWidth in sparse dots with int8 input (on Ampere)
PiperOrigin-RevId: 628368832
2024-04-26 04:53:25 -07:00
Jake VanderPlas
beb49af678 sparse_nm_test: skip on incompatible GPUs
PiperOrigin-RevId: 628120697
2024-04-25 10:38:07 -07:00
Sergey Kozub
aebe82a78f Add JAX API that provides sparse matmul support (2:4 structured sparsity)
Usage:
from jax.experimental.sparse import nm
res = nm.nm_spmm(lhs, rhs, nm.nm_pack(mask))

where:
lhs.shape = [M, K/2]
rhs.shape = [K, N]
`mask` has the same shape as `lhs` with boolean type

If batch dimensions are present, the `dimension_numbers` argument has to be set to:
((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))

The lowering only works on NVIDIA GPUs that provide hardware support for sparse dots.

PiperOrigin-RevId: 627640553
2024-04-24 01:06:19 -07:00
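The 2:4 structured-sparsity semantics described in the commit above can be illustrated with a plain NumPy reference. This is a hedged sketch, not the actual `nm` kernels: the helper names (`make_2to4_mask`, `compress`) are hypothetical, and it computes the equivalent dense result that `nm_spmm(lhs, rhs, nm_pack(mask))` is expected to produce, assuming the mask keeps exactly 2 of every 4 values along the contracting dimension.

```python
import numpy as np

def make_2to4_mask(m, k, rng):
    # Hypothetical helper: for each row and each group of 4 columns
    # along K, keep exactly 2 entries (the 2:4 sparsity pattern).
    mask = np.zeros((m, k), dtype=bool)
    for i in range(m):
        for g in range(0, k, 4):
            keep = rng.choice(4, size=2, replace=False)
            mask[i, g + keep] = True
    return mask

def compress(lhs_dense, mask):
    # Keep only the 2 surviving values per group of 4, giving the
    # [M, K/2] compressed operand that nm_spmm consumes as `lhs`.
    m, k = lhs_dense.shape
    return lhs_dense[mask].reshape(m, k // 2)

def dense_reference_spmm(lhs_dense, rhs, mask):
    # Reference semantics: zero out masked-off lhs entries, then matmul.
    return (lhs_dense * mask) @ rhs

rng = np.random.default_rng(0)
M, K, N = 4, 8, 5
lhs_dense = rng.standard_normal((M, K))
rhs = rng.standard_normal((K, N))
mask = make_2to4_mask(M, K, rng)

lhs_compressed = compress(lhs_dense, mask)  # shape [M, K/2]
out = dense_reference_spmm(lhs_dense, rhs, mask)  # shape [M, N]
```

On supporting hardware, the sparse path avoids materializing the zeroed entries: only `lhs_compressed` and the packed mask travel to the kernel, which is where the memory and throughput savings come from.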