lax_scipy_test: Split into three targets, take 2.

The goal is to ensure that all shards fit into a medium timeout in sanitizer
configurations.

Running 256-entry vectors in spectral_dac is too slow, so let's replace that
case with a smaller vector whose size isn't a power of 2. Avoiding a power of 2
requires widening the tolerance a bit due to vectorization changes.

While here, specify deps a little more precisely as well.

PiperOrigin-RevId: 514440062
Authored by pizzud on 2023-03-06 09:49:52 -08:00; committed by jax authors
parent 6aed604789
commit ef28dcf091
4 changed files with 276 additions and 134 deletions
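
For orientation, the sketch below (not part of the diff that follows) replays the residual check from the new test_spectral_dac_eigh at the replacement size 97, a non-power-of-2, against the widened 40x real-dtype bound introduced here. The fixed seed, float32 dtype, and termination_size=1 (the CPU case the test exercises) are illustrative choices, not requirements of the change.

import numpy as np
import jax.numpy as jnp
from jax import lax
from jax._src.lax import eigh as lax_eigh

linear_size = 97  # non-power-of-2 size replacing the old 256-entry case
rng = np.random.RandomState(0)
H = rng.randn(linear_size, linear_size)
H = jnp.asarray(0.5 * (H + H.T), dtype=jnp.float32)  # symmetrize

# termination_size=1 mirrors the CPU case exercised by the new test.
evs, V = lax_eigh.eigh(H, termination_size=1)
HV = jnp.dot(H, V, precision=lax.Precision.HIGHEST)  # H @ V
vV = evs.astype(V.dtype)[None, :] * V                # columns of V scaled by eigenvalues

eps = jnp.finfo(H.dtype).eps
atol = jnp.linalg.norm(H) * eps
residual = jnp.max(jnp.abs(HV - vV))
print(residual, "<=", 40 * atol)  # 40x is the widened real-dtype bound (was 30x)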

File: BUILD

@@ -394,15 +394,13 @@ jax_test(
jax_test(
name = "lax_scipy_test",
srcs = ["lax_scipy_test.py"],
backend_tags = {
"tpu": ["noasan"], # Test times out.
},
shard_count = {
"cpu": 40,
"gpu": 40,
"tpu": 40,
"cpu": 20,
"gpu": 20,
"tpu": 20,
"iree": 10,
},
deps = py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"),
)
jax_test(
@@ -419,6 +417,32 @@ jax_test(
},
)
jax_test(
name = "lax_scipy_special_functions_test",
srcs = ["lax_scipy_special_functions_test.py"],
shard_count = {
"cpu": 20,
"gpu": 20,
"tpu": 20,
"iree": 10,
},
deps = py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"),
)
jax_test(
name = "lax_scipy_spectral_dac_test",
srcs = ["lax_scipy_spectral_dac_test.py"],
shard_count = {
"cpu": 40,
"gpu": 40,
"tpu": 40,
"iree": 40,
},
deps = [
"//jax:internal_test_util",
] + py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"),
)
jax_test(
name = "lax_test",
srcs = ["lax_test.py"],

File: lax_scipy_special_functions_test.py (new file)

@@ -0,0 +1,182 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import functools
import itertools
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import scipy.special as osp_special
import jax
from jax._src import test_util as jtu
from jax.scipy import special as lsp_special
from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS

all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)]

OpRecord = collections.namedtuple(
"OpRecord",
["name", "nargs", "dtypes", "rng_factory", "test_autodiff", "nondiff_argnums", "test_name"])
def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), test_name=None):
test_name = test_name or name
nondiff_argnums = tuple(sorted(set(nondiff_argnums)))
return OpRecord(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums, test_name)

float_dtypes = jtu.dtypes.floating
int_dtypes = jtu.dtypes.integer
# TODO(phawkins): we should probably separate out the function domains used for
# autodiff tests from the function domains used for equivalence testing. For
# example, logit should closely match its scipy equivalent everywhere, but we
# don't expect numerical gradient tests to pass for inputs very close to 0.
JAX_SPECIAL_FUNCTION_RECORDS = [
op_record(
"betaln", 2, float_dtypes, jtu.rand_positive, False
),
op_record(
"betainc", 3, float_dtypes, jtu.rand_positive, False
),
op_record(
"digamma", 1, float_dtypes, jtu.rand_positive, True
),
op_record(
"gammainc", 2, float_dtypes, jtu.rand_positive, True
),
op_record(
"gammaincc", 2, float_dtypes, jtu.rand_positive, True
),
op_record(
"erf", 1, float_dtypes, jtu.rand_small_positive, True
),
op_record(
"erfc", 1, float_dtypes, jtu.rand_small_positive, True
),
op_record(
"erfinv", 1, float_dtypes, jtu.rand_small_positive, True
),
op_record(
"expit", 1, float_dtypes, jtu.rand_small_positive, True
),
# TODO: gammaln has slightly high error.
op_record(
"gammaln", 1, float_dtypes, jtu.rand_positive, False
),
op_record(
"i0", 1, float_dtypes, jtu.rand_default, True
),
op_record(
"i0e", 1, float_dtypes, jtu.rand_default, True
),
op_record(
"i1", 1, float_dtypes, jtu.rand_default, True
),
op_record(
"i1e", 1, float_dtypes, jtu.rand_default, True
),
op_record(
"logit", 1, float_dtypes,
functools.partial(jtu.rand_uniform, low=0.05, high=0.95), True),
op_record(
"log_ndtr", 1, float_dtypes, jtu.rand_default, True
),
op_record(
"ndtri", 1, float_dtypes,
functools.partial(jtu.rand_uniform, low=0.05, high=0.95), True,
),
op_record(
"ndtr", 1, float_dtypes, jtu.rand_default, True
),
# TODO(phawkins): gradient of entr yields NaNs.
op_record(
"entr", 1, float_dtypes, jtu.rand_default, False
),
op_record(
"polygamma", 2, (int_dtypes, float_dtypes),
jtu.rand_positive, True, (0,)),
op_record(
"xlogy", 2, float_dtypes, jtu.rand_positive, True
),
op_record(
"xlog1py", 2, float_dtypes, jtu.rand_default, True
),
# TODO: enable gradient test for zeta by restricting the domain of
# inputs to some reasonable intervals
op_record("zeta", 2, float_dtypes, jtu.rand_positive, False),
# TODO: float64 produces aborts on gpu, potentially related to use of jnp.piecewise
op_record(
"expi", 1, [np.float32],
functools.partial(jtu.rand_not_small, offset=0.1), True),
op_record("exp1", 1, [np.float32], jtu.rand_positive, True),
op_record(
"expn", 2, (int_dtypes, [np.float32]), jtu.rand_positive, True, (0,)),
]

class LaxScipySpcialFunctionsTest(jtu.JaxTestCase):
def _GetArgsMaker(self, rng, shapes, dtypes):
return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

@parameterized.parameters(itertools.chain.from_iterable(
jtu.sample_product_testcases(
[dict(op=rec.name, rng_factory=rec.rng_factory,
test_autodiff=rec.test_autodiff,
nondiff_argnums=rec.nondiff_argnums)],
shapes=itertools.combinations_with_replacement(all_shapes, rec.nargs),
dtypes=(itertools.combinations_with_replacement(rec.dtypes, rec.nargs)
if isinstance(rec.dtypes, list) else itertools.product(*rec.dtypes)),
)
for rec in JAX_SPECIAL_FUNCTION_RECORDS
))
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
def testScipySpecialFun(self, op, rng_factory, shapes, dtypes,
test_autodiff, nondiff_argnums):
scipy_op = getattr(osp_special, op)
lax_op = getattr(lsp_special, op)
rng = rng_factory(self.rng())
args_maker = self._GetArgsMaker(rng, shapes, dtypes)
args = args_maker()
self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3,
check_dtypes=False)
self._CompileAndCheck(lax_op, args_maker, rtol=1e-4)
if test_autodiff:
def partial_lax_op(*vals):
list_args = list(vals)
for i in nondiff_argnums:
list_args.insert(i, args[i])
return lax_op(*list_args)
assert list(nondiff_argnums) == sorted(set(nondiff_argnums))
diff_args = [x for i, x in enumerate(args) if i not in nondiff_argnums]
jtu.check_grads(partial_lax_op, diff_args, order=1,
atol=jtu.if_device_under_test("tpu", .1, 1e-3),
rtol=.1, eps=1e-3)

if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

File: lax_scipy_spectral_dac_test.py (new file)

@@ -0,0 +1,64 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from jax import lax
from jax import numpy as jnp
from jax._src import test_util as jtu
from jax._src.lax import eigh as lax_eigh
from absl.testing import absltest
from jax.config import config
config.parse_flags_with_absl()
FLAGS = config.FLAGS

linear_sizes = [16, 97, 128]

class LaxScipySpectralDacTest(jtu.JaxTestCase):
@jtu.sample_product(
linear_size=linear_sizes,
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
termination_size=[1, 19],
)
def test_spectral_dac_eigh(self, linear_size, dtype, termination_size):
if jtu.device_under_test() != "tpu" and termination_size != 1:
raise unittest.SkipTest(
"Termination sizes greater than 1 only work on TPU")
rng = self.rng()
H = rng.randn(linear_size, linear_size)
H = jnp.array(0.5 * (H + H.conj().T)).astype(dtype)
if jnp.dtype(dtype).name in ("bfloat16", "float16"):
self.assertRaises(
NotImplementedError, lax_eigh.eigh, H)
return
evs, V = lax_eigh.eigh(H, termination_size=termination_size)
ev_exp, _ = jnp.linalg.eigh(H)
HV = jnp.dot(H, V, precision=lax.Precision.HIGHEST)
vV = evs.astype(V.dtype)[None, :] * V
eps = jnp.finfo(H.dtype).eps
atol = jnp.linalg.norm(H) * eps
self.assertAllClose(ev_exp, jnp.sort(evs), atol=20 * atol)
self.assertAllClose(
HV, vV, atol=atol * (140 if jnp.issubdtype(dtype, jnp.complexfloating)
else 40))

if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

File: lax_scipy_test.py

@@ -13,14 +13,12 @@
# limitations under the License.
import collections
import functools
from functools import partial
import itertools
import unittest
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import scipy.special as osp_special
@@ -34,7 +32,6 @@ from jax.tree_util import tree_map
from jax._src import test_util as jtu
from jax.scipy import special as lsp_special
from jax.scipy import cluster as lsp_cluster
from jax._src.lax import eigh as lax_eigh
from jax.config import config
config.parse_flags_with_absl()
@@ -61,8 +58,6 @@ sides = ["right", "left"]
methods = ["qdwh", "svd"]
seeds = [1, 10]
linear_sizes = [16, 128, 256]
def _initialize_polar_test(rng, shape, n_zero_svs, degeneracy, geometric_spectrum,
max_sv, nonzero_condition_number, dtype):
@@ -94,67 +89,10 @@ def _initialize_polar_test(rng, shape, n_zero_svs, degeneracy, geometric_spectrum,
spectrum = jnp.array(svs).astype(dtype)
return result, spectrum
OpRecord = collections.namedtuple(
"OpRecord",
["name", "nargs", "dtypes", "rng_factory", "test_autodiff", "nondiff_argnums", "test_name"])
def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), test_name=None):
test_name = test_name or name
nondiff_argnums = tuple(sorted(set(nondiff_argnums)))
return OpRecord(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums, test_name)
# TODO(phawkins): we should probably separate out the function domains used for
# autodiff tests from the function domains used for equivalence testing. For
# example, logit should closely match its scipy equivalent everywhere, but we
# don't expect numerical gradient tests to pass for inputs very close to 0.
JAX_SPECIAL_FUNCTION_RECORDS = [
op_record("betaln", 2, float_dtypes, jtu.rand_positive, False),
op_record("betainc", 3, float_dtypes, jtu.rand_positive, False),
op_record("digamma", 1, float_dtypes, jtu.rand_positive, True),
op_record("gammainc", 2, float_dtypes, jtu.rand_positive, True),
op_record("gammaincc", 2, float_dtypes, jtu.rand_positive, True),
op_record("erf", 1, float_dtypes, jtu.rand_small_positive, True),
op_record("erfc", 1, float_dtypes, jtu.rand_small_positive, True),
op_record("erfinv", 1, float_dtypes, jtu.rand_small_positive, True),
op_record("expit", 1, float_dtypes, jtu.rand_small_positive, True),
# TODO: gammaln has slightly high error.
op_record("gammaln", 1, float_dtypes, jtu.rand_positive, False),
op_record("i0", 1, float_dtypes, jtu.rand_default, True),
op_record("i0e", 1, float_dtypes, jtu.rand_default, True),
op_record("i1", 1, float_dtypes, jtu.rand_default, True),
op_record("i1e", 1, float_dtypes, jtu.rand_default, True),
op_record("logit", 1, float_dtypes, partial(jtu.rand_uniform, low=0.05,
high=0.95), True),
op_record("log_ndtr", 1, float_dtypes, jtu.rand_default, True),
op_record("ndtri", 1, float_dtypes, partial(jtu.rand_uniform, low=0.05,
high=0.95),
True),
op_record("ndtr", 1, float_dtypes, jtu.rand_default, True),
# TODO(phawkins): gradient of entr yields NaNs.
op_record("entr", 1, float_dtypes, jtu.rand_default, False),
op_record("polygamma", 2, (int_dtypes, float_dtypes), jtu.rand_positive, True, (0,)),
op_record("xlogy", 2, float_dtypes, jtu.rand_positive, True),
op_record("xlog1py", 2, float_dtypes, jtu.rand_default, True),
# TODO: enable gradient test for zeta by restricting the domain of
# of inputs to some reasonable intervals
op_record("zeta", 2, float_dtypes, jtu.rand_positive, False),
# TODO: float64 produces aborts on gpu, potentially related to use of jnp.piecewise
op_record("expi", 1, [np.float32], partial(jtu.rand_not_small, offset=0.1),
True),
op_record("exp1", 1, [np.float32], jtu.rand_positive, True),
op_record("expn", 2, (int_dtypes, [np.float32]), jtu.rand_positive, True, (0,)),
]
class LaxBackedScipyTests(jtu.JaxTestCase):
"""Tests for LAX-backed Scipy implementation."""
def _GetArgsMaker(self, rng, shapes, dtypes):
return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
@jtu.sample_product(
[dict(shapes=shapes, axis=axis, use_b=use_b)
for shape_group in compatible_shapes
@@ -232,43 +170,6 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
result = lsp_special.logsumexp(1.0, b=1.0)
self.assertEqual(result, 1.0)
@parameterized.parameters(itertools.chain.from_iterable(
jtu.sample_product_testcases(
[dict(op=rec.name, rng_factory=rec.rng_factory,
test_autodiff=rec.test_autodiff,
nondiff_argnums=rec.nondiff_argnums)],
shapes=itertools.combinations_with_replacement(all_shapes, rec.nargs),
dtypes=(itertools.combinations_with_replacement(rec.dtypes, rec.nargs)
if isinstance(rec.dtypes, list) else itertools.product(*rec.dtypes)),
)
for rec in JAX_SPECIAL_FUNCTION_RECORDS
))
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
def testScipySpecialFun(self, op, rng_factory, shapes, dtypes,
test_autodiff, nondiff_argnums):
scipy_op = getattr(osp_special, op)
lax_op = getattr(lsp_special, op)
rng = rng_factory(self.rng())
args_maker = self._GetArgsMaker(rng, shapes, dtypes)
args = args_maker()
self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3,
check_dtypes=False)
self._CompileAndCheck(lax_op, args_maker, rtol=1e-4)
if test_autodiff:
def partial_lax_op(*vals):
list_args = list(vals)
for i in nondiff_argnums:
list_args.insert(i, args[i])
return lax_op(*list_args)
assert list(nondiff_argnums) == sorted(set(nondiff_argnums))
diff_args = [x for i, x in enumerate(args) if i not in nondiff_argnums]
jtu.check_grads(partial_lax_op, diff_args, order=1,
atol=jtu.if_device_under_test("tpu", .1, 1e-3),
rtol=.1, eps=1e-3)
@jtu.sample_product(
shape=all_shapes,
dtype=float_dtypes,
@@ -581,34 +482,6 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self.assertAllClose(
matrix, recon, atol=tol * jnp.linalg.norm(matrix))
@jtu.sample_product(
linear_size=linear_sizes,
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
termination_size=[1, 19],
)
def test_spectral_dac_eigh(self, linear_size, dtype, termination_size):
if jtu.device_under_test() != "tpu" and termination_size != 1:
raise unittest.SkipTest(
"Termination sizes greater than 1 only work on TPU")
rng = self.rng()
H = rng.randn(linear_size, linear_size)
H = jnp.array(0.5 * (H + H.conj().T)).astype(dtype)
if jnp.dtype(dtype).name in ("bfloat16", "float16"):
self.assertRaises(
NotImplementedError, lax_eigh.eigh, H)
return
evs, V = lax_eigh.eigh(H, termination_size=termination_size)
ev_exp, eV_exp = jnp.linalg.eigh(H)
HV = jnp.dot(H, V, precision=lax.Precision.HIGHEST)
vV = evs.astype(V.dtype)[None, :] * V
eps = jnp.finfo(H.dtype).eps
atol = jnp.linalg.norm(H) * eps
self.assertAllClose(ev_exp, jnp.sort(evs), atol=20 * atol)
self.assertAllClose(
HV, vV, atol=atol * (140 if jnp.issubdtype(dtype, jnp.complexfloating)
else 30))
@jtu.sample_product(
n_obs=[1, 3, 5],
n_codes=[1, 2, 4],
@@ -645,6 +518,5 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self.assertArraysEqual(actual, nan_array, check_dtypes=False)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())