diff --git a/jax/experimental/sparse/nm.py b/jax/experimental/sparse/nm.py new file mode 100644 index 000000000..251bf45f0 --- /dev/null +++ b/jax/experimental/sparse/nm.py @@ -0,0 +1,241 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""N:M-sparsity associated primitives.""" + +from jax import core +from jax._src import dispatch +from jax._src.lax.lax import DotDimensionNumbers +from jax._src.lib import gpu_sparse +from jax._src.lib.mlir.dialects import mhlo +from jax._src.typing import Array, DTypeLike +from jax.interpreters import mlir +import jax.numpy as jnp +import numpy as np + +# -------------------------------------------------------------------- +# nm_spmm + +nm_spmm_p = core.Primitive("sparse_dense_matmul") + +_supported_input_types = (jnp.int8, jnp.int16, jnp.float16, jnp.bfloat16) +_supported_output_types = (jnp.bfloat16, jnp.float32) + + +def nm_spmm( + lhs: Array, + rhs: Array, + metadata: Array, + dimension_numbers: DotDimensionNumbers = (((1,), (0,)), (tuple(), tuple())), + sparse_operand_idx: int = 0, + output_dtype: DTypeLike = jnp.bfloat16, +) -> Array: + """Dot operation where one of the operands has N:M sparsity. + + Args: + lhs: An ndarray (first dot operand). + rhs: An ndarray (second dot operand). + metadata: An ndarray with structured sparsity metadata for the contracting + dimension. For 2:4 sparsity it should contain (N=2) two-bit index values + for each (M=4) element group. + dimension_numbers: a tuple of tuples of the form `((lhs_contracting_dims, + rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`. + sparse_operand_idx: index of the sparse operand (0 or 1). + output_dtype: result type. + + Returns: + An ndarray dense array containing the result. + """ + return nm_spmm_p.bind( + lhs, + rhs, + metadata, + dimension_numbers=dimension_numbers, + sparse_operand_idx=sparse_operand_idx, + output_dtype=output_dtype, + ) + + +def _calc_groups_per_element(n, m): + group_bits = n * (m.bit_length() - 1) # 4 bits per group for 2:4 + return 16 // group_bits + + +def _validate_dnums(rank, contract, batch, name): + non_contract = tuple(sorted(set(range(rank)) - set(contract + batch))) + if sorted(non_contract + contract + batch) != list(range(rank)): + raise TypeError(f"Incorrect dimension numbers for {name}") + return non_contract + + +def _validate_metadata(lhs, rhs, metadata, dimension_numbers, index, n=2, m=4): + assert index in (0, 1) + size_factor = n * _calc_groups_per_element(n, m) + + sparse = [lhs, rhs][index] + sparse_contract = dimension_numbers[0][index] + if metadata.dtype != np.uint16: + raise TypeError(f"Metadata must be uint16, got {metadata.dtype}") + if sparse_contract[0] != sparse.ndim - 1: + raise TypeError("Contracting dimension must be the minor one") + if metadata.shape[:-1] != sparse.shape[:-1]: + raise TypeError( + "Metadata shape must match the operand shape (except for the" + " contracting dimension)" + ) + if metadata.shape[-1] * size_factor != sparse.shape[-1]: + raise TypeError( + f"Metadata must be exactly {size_factor} times less than the" + f" contracting dimension for {n}:{m} structured sparsity (expected" + f" {sparse.shape[-1] // size_factor}, got {metadata.shape[-1]})" + ) + if sparse.shape[-1] % size_factor != 0: + raise NotImplementedError("Metadata with padding is not supported") + + dense = [lhs, rhs][1 - index] + dense_contract = dimension_numbers[0][1 - index] + a, b = sparse.shape[sparse_contract[0]], dense.shape[dense_contract[0]] + if n * b != m * a: + raise TypeError( + f"Contracting dimension sizes should have {n}:{m} ratio, got {a}:{b}" + ) + + +def _infer_result_shape(lhs, rhs, dimension_numbers): + ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) = dimension_numbers + if len(lhs_contract) != 1 or len(rhs_contract) != 1: + raise TypeError("Only single contracting dimension is supported") + lhs_dims = _validate_dnums(lhs.ndim, lhs_contract, lhs_batch, "lhs") + rhs_dims = _validate_dnums(rhs.ndim, rhs_contract, rhs_batch, "rhs") + if len(lhs_dims) != 1 or len(rhs_dims) != 1: + raise TypeError("Only single non-contracting dimension is supported") + batch = [lhs.shape[i] for i in lhs_batch] + if batch != [rhs.shape[i] for i in rhs_batch]: + raise TypeError("Batch dimension sizes do not match") + return tuple(batch + [lhs.shape[lhs_dims[0]], rhs.shape[rhs_dims[0]]]) + + +def _nm_spmm_default_lowering(*_args, **_kwargs): + raise NotImplementedError("Sparse N:M matmul is only implemented on GPU") + + +def _nm_spmm_gpu_lowering( + ctx, + lhs, + rhs, + metadata, + *, + dimension_numbers, + sparse_operand_idx, + output_dtype, +): + assert sparse_operand_idx in (0, 1) + sparsity_descriptor = mhlo.SparsityDescriptor.get( + dimension=dimension_numbers[0][sparse_operand_idx][0], n=2, m=4 + ) + dot_dnums = mhlo.DotDimensionNumbers.get( + lhs_batching_dimensions=dimension_numbers[1][sparse_operand_idx], + rhs_batching_dimensions=dimension_numbers[1][1 - sparse_operand_idx], + lhs_contracting_dimensions=dimension_numbers[0][sparse_operand_idx], + rhs_contracting_dimensions=dimension_numbers[0][1 - sparse_operand_idx], + ) + dot_type = ctx.avals_out[0] + key = ["lhs_sparsity", "rhs_sparsity"][sparse_operand_idx] + kwargs = {key: sparsity_descriptor} + op = mhlo.SparseDotOp( + mlir.aval_to_ir_type(dot_type), lhs, rhs, [metadata], dot_dnums, **kwargs + ) + return op.results + + +@nm_spmm_p.def_abstract_eval +def _nm_spmm_abstract_eval( + lhs, rhs, metadata, *, dimension_numbers, sparse_operand_idx, output_dtype +): + if lhs.dtype not in _supported_input_types: + raise TypeError(f"Unsupported lhs input type: {lhs.dtype}") + if rhs.dtype not in _supported_input_types: + raise TypeError(f"Unsupported rhs input type: {rhs.dtype}") + if output_dtype not in _supported_output_types: + raise TypeError(f"Unsupported output type: {output_dtype}") + + res_shape = _infer_result_shape(lhs, rhs, dimension_numbers) + _validate_metadata(lhs, rhs, metadata, dimension_numbers, sparse_operand_idx) + return core.ShapedArray(res_shape, output_dtype) + + +mlir.register_lowering(nm_spmm_p, _nm_spmm_default_lowering) +dispatch.simple_impl(nm_spmm_p) + +if gpu_sparse.cuda_is_supported: + mlir.register_lowering(nm_spmm_p, _nm_spmm_gpu_lowering, platform="cuda") + +# -------------------------------------------------------------------- +# nm_pack + +nm_pack_p = core.Primitive("sparse_pack_nm") + + +def nm_pack(mask: Array, n=2, m=4) -> Array: + """Generate metadata tensor for an N:M mask. + + Args: + mask: Predicates for the input tensor, where the elements are grouped in the + minor dimension. In each group of size M there should be exactly N true + values, which mark the data elements to keep. + n: Number of non-zero elements in a group. + m: Group size. + + Returns: + An ndarray containing only the masked input elements. + """ + return nm_pack_p.bind(mask, n=n, m=m) + + +def _compress(data, n, m, k): + result = [] + expected = n * (k // m) + for i in range(0, len(data), k): + index = tuple(jnp.nonzero(data[i : i + k], size=expected)[0] % m) + value = sum(j * pow(m, i) for i, j in enumerate(index)) + result.append(value) + return jnp.array(result, dtype=np.uint16) + + +@nm_pack_p.def_impl +def _nm_pack_impl(mask, *, n, m): + batch_size = m * _calc_groups_per_element(n, m) + return jnp.apply_along_axis( + lambda x: _compress(x, n, m, batch_size), -1, mask + ) + + +@nm_pack_p.def_abstract_eval +def _nm_pack_abstract_eval(mask, *, n, m): + size_factor = m * _calc_groups_per_element(n, m) + if mask.dtype != bool: + raise TypeError(f"Mask should be bool, got {mask.dtype}") + if mask.shape[-1] % size_factor != 0: + raise TypeError( + f"Inner dimension size should be divisible by {size_factor}, got" + f" {mask.shape}" + ) + res_shape = list(mask.shape) + res_shape[-1] //= size_factor + return core.ShapedArray(res_shape, np.uint16) + + +_nm_pack_lowering = mlir.lower_fun(_nm_pack_impl, multiple_results=False) +mlir.register_lowering(nm_pack_p, _nm_pack_lowering) +dispatch.simple_impl(nm_pack_p) diff --git a/tests/BUILD b/tests/BUILD index 02cce6b31..987883cd1 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -984,6 +984,23 @@ jax_test( ] + py_deps("scipy"), ) +jax_test( + name = "sparse_nm_test", + srcs = ["sparse_nm_test.py"], + disable_backends = [ + "cpu", + "gpu", + "tpu", + ], + enable_configs = [ + "gpu_a100", + "gpu_h100", + ], + deps = [ + "//jax:experimental_sparse", + ], +) + jax_test( name = "sparsify_test", srcs = ["sparsify_test.py"], diff --git a/tests/sparse_nm_test.py b/tests/sparse_nm_test.py new file mode 100644 index 000000000..ae4455145 --- /dev/null +++ b/tests/sparse_nm_test.py @@ -0,0 +1,200 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from absl.testing import absltest +from absl.testing import parameterized + +import jax +from jax import dtypes +from jax._src import test_util as jtu +from jax.experimental.sparse import nm +import jax.numpy as jnp +import numpy as np + +jax.config.parse_flags_with_absl() + + +class SpmmTest(jtu.JaxTestCase): + # ----- Test different input shapes + @parameterized.product( + tile_m=(32, 128), + tile_n=(32, 128), + tile_k=(32, 128), + batch=(None, 5), + sparse_idx=(0, 1), + ) + @jtu.run_on_devices("gpu") + def test_shapes(self, tile_m, tile_n, tile_k, batch, sparse_idx): + # Build keyword arguments + kwargs = { + "dimension_numbers": (((1,), (1,)), (tuple(), tuple())), + "sparse_operand_idx": sparse_idx, + } + if batch: + kwargs["dimension_numbers"] = (((2,), (2,)), ((0,), (0,))) + + # Build input data + batch_dims = (batch,) if batch else tuple() + lhs = ( + (np.arange((batch or 1) * tile_m * tile_k) % 11) + .astype(dtypes.bfloat16) + .reshape(batch_dims + (tile_m, tile_k)) + ) + rhs = ( + (np.arange((batch or 1) * tile_n * tile_k) % 13) + .astype(dtypes.bfloat16) + .reshape(batch_dims + (tile_n, tile_k)) + ) + + # Build sparsity mask and metadata + sp = [lhs, rhs][sparse_idx] + mask = np.tile([True, False], math.prod(sp.shape) // 2).reshape(sp.shape) + sparse = sp[mask].reshape(sp.shape[:-1] + (sp.shape[-1] // 2,)) + meta = nm.nm_pack(mask) + + # Calculate sparse and dense dots + if sparse_idx == 0: + dot_sparse = nm.nm_spmm(sparse, rhs, meta, **kwargs) + dot_dense = jnp.einsum("...mk,...nk->...mn", (lhs * mask), rhs) + else: + dot_sparse = nm.nm_spmm(lhs, sparse, meta, **kwargs) + dot_dense = jnp.einsum("...mk,...nk->...mn", lhs, (rhs * mask)) + + # Verify the result + jtu.check_eq(dot_sparse, dot_dense.astype(dtypes.bfloat16)) + + # ----- Test different input types + # TODO(b/336519663): add int8 type once codegen is fixed + @parameterized.product( + lhs_type=[jnp.int16, jnp.float16, jnp.bfloat16], + rhs_type=[jnp.bfloat16], + output_type=[jnp.bfloat16, jnp.float32], + ) + @jtu.run_on_devices("gpu") + def test_types(self, lhs_type, rhs_type, output_type): + tile_m, tile_n, tile_k = 64, 32, 128 + + # Build input data + lhs = ( + (np.arange(tile_m * tile_k) % 17) + .astype(lhs_type) + .reshape((tile_m, tile_k)) + ) + rhs = ( + (np.arange(tile_k * tile_n) % 19) + .astype(rhs_type) + .reshape((tile_k, tile_n)) + ) + + # Build sparsity mask and metadata + mask = np.tile([True, False], tile_m * tile_k // 2).reshape(lhs.shape) + sparse = lhs[mask].reshape(tile_m, tile_k // 2) + meta = nm.nm_pack(mask) + + # Calculate sparse and dense dots + dot_sparse = nm.nm_spmm(sparse, rhs, meta, output_dtype=output_type) + dot_dense = (lhs * mask) @ rhs + + # Verify the result + jtu.check_close(dot_sparse, dot_dense.astype(output_type), rtol=0.01) + + # ----- Test validation + @jtu.run_on_devices("gpu") + def test_validate_nm_pack(self): + with self.assertRaisesRegex(TypeError, "Mask should be bool"): + nm.nm_pack(jnp.zeros(16, jnp.int8)) + with self.assertRaisesRegex( + TypeError, "Inner dimension size should be divisible by 16" + ): + nm.nm_pack(jnp.array([False] * 8)) + + @jtu.run_on_devices("gpu") + def test_validate_nm_spmm(self): + batch, tile_m, tile_n, tile_k = 2, 64, 32, 128 + lhs = jnp.zeros((batch, tile_m, tile_k // 2), dtype=jnp.bfloat16) + rhs = jnp.zeros((batch, tile_k, tile_n), dtype=jnp.bfloat16) + meta = jnp.zeros((batch, tile_m, tile_k // 16), dtype=jnp.uint16) + + # Check types + with self.assertRaisesRegex(TypeError, "Unsupported lhs input type"): + nm.nm_spmm(jnp.zeros(lhs.shape, dtype=jnp.int64), rhs, meta) + with self.assertRaisesRegex(TypeError, "Unsupported rhs input type"): + nm.nm_spmm(lhs, jnp.zeros(rhs.shape, dtype=jnp.int64), meta) + with self.assertRaisesRegex(TypeError, "Unsupported output type"): + nm.nm_spmm(lhs, rhs, meta, output_dtype=jnp.int64) + + # Check dimension numbers + nm_spmm_with_dnums = lambda c, b: nm.nm_spmm( + lhs, rhs, meta, dimension_numbers=(c, b) + ) + with self.assertRaisesRegex( + TypeError, "Only single contracting dimension is supported" + ): + nm_spmm_with_dnums(((0, 2), (0, 1)), (tuple(), tuple())) + with self.assertRaisesRegex( + TypeError, "Incorrect dimension numbers for lhs" + ): + nm_spmm_with_dnums(((2,), (1,)), ((2,), (0,))) + with self.assertRaisesRegex( + TypeError, "Incorrect dimension numbers for rhs" + ): + nm_spmm_with_dnums(((2,), (1,)), ((0,), (1,))) + with self.assertRaisesRegex( + TypeError, "Only single non-contracting dimension is supported" + ): + nm_spmm_with_dnums(((2,), (1,)), (tuple(), tuple())) + with self.assertRaisesRegex( + TypeError, "Batch dimension sizes do not match" + ): + nm.nm_spmm( + lhs, + rhs.reshape(1, tile_k, tile_n * batch), + meta, + dimension_numbers=(((2,), (1,)), ((0,), (0,))), + ) + + # Check metadata + nm_spmm_with_meta = lambda m: nm.nm_spmm( + lhs, rhs, m, dimension_numbers=(((2,), (1,)), ((0,), (0,))) + ) + with self.assertRaisesRegex(TypeError, "Metadata must be uint16"): + nm_spmm_with_meta(jnp.zeros(meta.shape, dtype=jnp.uint8)) + with self.assertRaisesRegex( + TypeError, "Metadata shape must match the operand shape" + ): + nm_spmm_with_meta(meta.reshape(1, batch * tile_m, tile_k // 16)) + with self.assertRaisesRegex( + TypeError, + "Metadata must be exactly 8 times less than the contracting dimension" + " for 2:4 structured sparsity", + ): + nm_spmm_with_meta(jnp.repeat(meta, 2, axis=-1)) + with self.assertRaisesRegex( + TypeError, "Contracting dimension must be the minor one" + ): + nm.nm_spmm(lhs, rhs, meta, dimension_numbers=(((1,), (1,)), ((0,), (0,)))) + with self.assertRaisesRegex( + TypeError, "Contracting dimension sizes should have 2:4 ratio" + ): + nm.nm_spmm( + lhs, + jnp.repeat(rhs, 2, axis=1), + meta, + dimension_numbers=(((2,), (1,)), ((0,), (0,))), + ) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader())