diff --git a/tests/BUILD b/tests/BUILD index df9a28236..134e7a4ba 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -82,9 +82,13 @@ jax_multiplatform_test( }, ) -jax_multiplatform_test( +jax_py_test( name = "config_test", srcs = ["config_test.py"], + deps = [ + "//jax", + "//jax:test_util", + ] + py_deps("absl/testing"), ) jax_multiplatform_test( @@ -96,11 +100,6 @@ jax_multiplatform_test( }, ) -jax_multiplatform_test( - name = "custom_object_test", - srcs = ["custom_object_test.py"], -) - jax_multiplatform_test( name = "debug_nans_test", srcs = ["debug_nans_test.py"], @@ -397,11 +396,6 @@ jax_multiplatform_test( jax_multiplatform_test( name = "custom_linear_solve_test", srcs = ["custom_linear_solve_test.py"], - shard_count = { - "cpu": 10, - "gpu": 10, - "tpu": 10, - }, ) jax_multiplatform_test( diff --git a/tests/custom_object_test.py b/tests/custom_object_test.py deleted file mode 100644 index 4b1182e16..000000000 --- a/tests/custom_object_test.py +++ /dev/null @@ -1,348 +0,0 @@ -# Copyright 2020 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest -import math -import unittest - -import numpy as np - -import jax -import jax.numpy as jnp -from jax import jit, lax, make_jaxpr -from jax.interpreters import mlir -from jax.interpreters import xla - -from jax._src import core -from jax._src import dtypes -from jax._src import test_util as jtu -from jax._src import xla_bridge -from jax._src.lib.mlir import ir -from jax._src.lib import xla_client - -xc = xla_client -xb = xla_bridge - -jax.config.parse_flags_with_absl() - -# TODO(jakevdp): use a setup/teardown method to populate and unpopulate all the -# dictionaries associated with the following objects. - -# Define a sparse array data structure. The important feature here is that -# it is a jaxpr object that is backed by two device buffers. -class SparseArray: - """Simple sparse COO array data structure.""" - def __init__(self, aval, data, indices): - self.aval = aval - self.shape = aval.shape - self.data = data - self.indices = indices - - @property - def index_dtype(self): - return self.indices.dtype - - @property - def dtype(self): - return self.data.dtype - - @property - def nnz(self): - return self.data.shape[0] - - def __repr__(self): - return repr([(tuple(ind), d) for ind, d in zip(self.indices, self.data)]) - - -class AbstractSparseArray(core.ShapedArray): - __slots__ = ['index_dtype', 'nnz', 'data_aval', 'indices_aval'] - - def __init__(self, shape, dtype, index_dtype, nnz, weak_type=False): - super().__init__(shape, dtypes.canonicalize_dtype(dtype)) - self.index_dtype = index_dtype - self.nnz = nnz - self.data_aval = core.ShapedArray( - (nnz,), dtypes.canonicalize_dtype(dtype), weak_type) - self.indices_aval = core.ShapedArray( - (nnz, len(shape)), dtypes.canonicalize_dtype(index_dtype)) - - def update(self, shape=None, dtype=None, index_dtype=None, nnz=None, - weak_type=None): - if shape is None: - shape = self.shape - if dtype is None: - dtype = self.dtype - if index_dtype is None: - index_dtype = self.dtype - if nnz is None: - nnz = self.nnz - if weak_type is None: - weak_type = self.weak_type - return AbstractSparseArray(shape, dtype, index_dtype, nnz, weak_type) - - def strip_weak_type(self): - return self - - @core.aval_property - def data(self): - return sp_data_p.bind(self) - - @core.aval_property - def indices(self): - return sp_indices_p.bind(self) - -class ConcreteSparseArray(AbstractSparseArray): - pass - - -def sparse_array_shape_handler(a): - return ( - xc.Shape.array_shape(a.data_aval.dtype, a.data_aval.shape), - xc.Shape.array_shape(a.indices_aval.dtype, a.indices_aval.shape), - ) - - -core.pytype_aval_mappings[SparseArray] = lambda x: x.aval -core.raise_to_shaped_mappings[AbstractSparseArray] = lambda aval, _: aval -xla.pytype_aval_mappings[SparseArray] = lambda x: x.aval -xla.canonicalize_dtype_handlers[SparseArray] = lambda x: x - -def sparse_array_mlir_type_handler(a): - return ( - ir.RankedTensorType.get( - a.data_aval.shape, mlir.dtype_to_ir_type(a.data_aval.dtype)), - ir.RankedTensorType.get( - a.indices_aval.shape, mlir.dtype_to_ir_type(a.indices_aval.dtype)), - ) - -mlir.ir_type_handlers[AbstractSparseArray] = sparse_array_mlir_type_handler - -sp_indices_p = core.Primitive('sp_indices') - -@sp_indices_p.def_impl -def _sp_indices_impl(mat): - return mat.indices - -@sp_indices_p.def_abstract_eval -def _sp_indices_abstract_eval(mat): - return mat.indices_aval - -# Note: cannot use lower_fun to define attribute access primitives -# because it leads to infinite recursion. - -def _sp_indices_hlo_lowering(ctx, data_and_indices): - return [data_and_indices[1]] - -mlir.register_lowering(sp_indices_p, _sp_indices_hlo_lowering) - -sp_data_p = core.Primitive('sp_data') - -@sp_data_p.def_impl -def _sp_data_impl(mat): - return mat.data - -@sp_data_p.def_abstract_eval -def _sp_data_abstract_eval(mat): - return mat.data_aval - -# Note: cannot use lower_fun to define attribute access primitives -# because it leads to infinite recursion. - -def _sp_data_hlo_lowering(ctx, data_and_indices): - return [data_and_indices[0]] - -mlir.register_lowering(sp_data_p, _sp_data_hlo_lowering) - -def identity(x): - return identity_p.bind(x) - -identity_p = core.Primitive('identity') - -@identity_p.def_impl -def _identity_impl(mat): - return mat - -@identity_p.def_abstract_eval -def _identity_abstract_eval(mat): - return AbstractSparseArray(mat.shape, mat.dtype, mat.index_dtype, mat.nnz) - -mlir.register_lowering( - identity_p, mlir.lower_fun(_identity_impl, multiple_results=False)) - -def split(x): - return split_p.bind(x) - -split_p = core.Primitive('split') -split_p.multiple_results = True - -@split_p.def_impl -def _split_impl(mat): - return mat, mat - -@split_p.def_abstract_eval -def _split_abstract_eval(mat): - m = AbstractSparseArray(mat.shape, mat.dtype, mat.index_dtype, mat.nnz) - return m, m - -mlir.register_lowering( - split_p, mlir.lower_fun(_split_impl, multiple_results=True)) - -def make_sparse_array(rng, shape, dtype, nnz=0.2): - mat = rng(shape, dtype) - size = math.prod(shape) - if 0 < nnz < 1: - nnz = nnz * size - nnz = int(nnz) - if nnz == 0: - mat = np.zeros_like(mat) - elif nnz < size: - # TODO(jakevdp): do we care about duplicates? - cutoff = np.sort(mat.ravel())[nnz] - mat[mat >= cutoff] = 0 - nz = (mat != 0) - data = jnp.array(mat[nz]) - indices = jnp.array(np.where(nz)).T - aval = AbstractSparseArray(shape, data.dtype, indices.dtype, len(indices)) - return SparseArray(aval, data, indices) - -def matvec(mat, v): - v = jnp.asarray(v) - assert v.ndim == 1 - assert len(mat.shape) == 2 - assert v.shape[0] == mat.shape[1] - rows = mat.indices[:, 0] - cols = mat.indices[:, 1] - dv = mat.data * v[cols] - return jnp.zeros(mat.shape[0], dtype=dv.dtype).at[rows].add(dv) - - -class Empty: - def __init__(self, aval): - self.aval = aval - -class AbstractEmpty(core.AbstractValue): - - def join(self, other): - assert isinstance(other, self.__class__), other - return self - - def __hash__(self): - return hash(()) - - def __eq__(self, other): - return isinstance(other, AbstractEmpty) - -class ConcreteEmpty(AbstractEmpty): - pass - - -core.pytype_aval_mappings[Empty] = lambda x: ConcreteEmpty() -core.raise_to_shaped_mappings[AbstractEmpty] = lambda aval, _: aval -xla.pytype_aval_mappings[Empty] = lambda x: AbstractEmpty() -xla.canonicalize_dtype_handlers[Empty] = lambda x: x - - -@unittest.skip("Test does not work with jax.Array") -class CustomObjectTest(jtu.JaxTestCase): - - @jtu.sample_product( - primitive=[True, False], - compile=[True, False], - ) - def testSparseIdentity(self, compile, primitive): - f = identity if primitive else (lambda x: x) - f = jit(f) if compile else f - rng = jtu.rand_default(self.rng()) - M = make_sparse_array(rng, (10,), jnp.float32) - M2 = f(M) - - jaxpr = make_jaxpr(f)(M).jaxpr - core.check_jaxpr(jaxpr) - - self.assertEqual(M.dtype, M2.dtype) - self.assertEqual(M.index_dtype, M2.index_dtype) - self.assertAllClose(M.data, M2.data) - self.assertAllClose(M.indices, M2.indices) - - @jtu.sample_product( - compile=[True, False], - ) - def testSparseSplit(self, compile): - f = jit(split) if compile else split - rng = jtu.rand_default(self.rng()) - M = make_sparse_array(rng, (10,), jnp.float32) - M2, M3 = f(M) - - jaxpr = make_jaxpr(f)(M).jaxpr - core.check_jaxpr(jaxpr) - - for MM in M2, M3: - self.assertEqual(M.dtype, MM.dtype) - self.assertEqual(M.index_dtype, MM.index_dtype) - self.assertArraysEqual(M.data, MM.data) - self.assertArraysEqual(M.indices, MM.indices) - - @jtu.sample_product( - primitive=[True, False], - compile=[True, False], - ) - def testSparseLaxLoop(self, compile, primitive): - rng = jtu.rand_default(self.rng()) - f = identity if primitive else (lambda x: x) - f = jit(f) if compile else f - body_fun = lambda _, A: f(A) - M = make_sparse_array(rng, (10,), jnp.float32) - lax.fori_loop(0, 10, body_fun, M) - - @jtu.sample_product(attr=["data", "indices"]) - def testSparseAttrAccess(self, attr): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [make_sparse_array(rng, (10,), jnp.float32)] - f = lambda x: getattr(x, attr) - self._CompileAndCheck(f, args_maker) - - @jtu.sample_product( - shape=[(3, 3), (2, 6), (6, 2)], - dtype=jtu.dtypes.floating, - ) - def testSparseMatvec(self, shape, dtype): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [make_sparse_array(rng, shape, dtype), rng(shape[-1:], dtype)] - self._CompileAndCheck(matvec, args_maker) - - def testLowerToNothing(self): - empty = Empty(AbstractEmpty()) - jaxpr = make_jaxpr(jit(lambda e: e))(empty).jaxpr - core.check_jaxpr(jaxpr) - - # cannot return a unit, because CompileAndCheck assumes array output. - testfunc = lambda e: None - args_maker = lambda: [empty] - self._CompileAndCheck(testfunc, args_maker) - - def testConstantHandler(self): - def make_const_array(): - data = np.arange(3.0) - indices = np.arange(3)[:, None] - shape = (5,) - aval = AbstractSparseArray(shape, data.dtype, indices.dtype, len(indices)) - return SparseArray(aval, data, indices) - out1 = make_const_array() - out2 = jit(make_const_array)() - self.assertArraysEqual(out1.data, out2.data) - self.assertArraysEqual(out1.indices, out2.indices) - - -if __name__ == '__main__': - absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 044f82067..0867d6920 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -16,6 +16,7 @@ load( "//jaxlib:jax.bzl", "jax_generate_backend_suites", "jax_multiplatform_test", + "jax_py_test", "py_deps", ) @@ -438,17 +439,16 @@ jax_multiplatform_test( ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), ) -jax_multiplatform_test( +# This test doesn't need a TPU; it only tests numpy-using helpers. +jax_py_test( name = "tpu_splash_attention_mask_test", srcs = [ "tpu_splash_attention_mask_test.py", ], - enable_backends = [ - "cpu", - "tpu", - ], deps = [ + "//jax", "//jax:pallas_tpu_ops", + "//jax:test_util", ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), )