mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
lax_scipy_test: Split into three targets, take 2.
The goal is to ensure that all shards fit into a medium timeout in sanitizer configurations. Running 256 entry vectors in spectral_dac is too slow, so let's replace that with a smaller vector that isn't a power of 2. Avoiding a power of 2 requires us to widen the tolerance a bit due to vectorization changes. While here, specify deps a little more precisely as well. PiperOrigin-RevId: 514440062
This commit is contained in:
parent
6aed604789
commit
ef28dcf091
36
tests/BUILD
36
tests/BUILD
@ -394,15 +394,13 @@ jax_test(
|
||||
jax_test(
|
||||
name = "lax_scipy_test",
|
||||
srcs = ["lax_scipy_test.py"],
|
||||
backend_tags = {
|
||||
"tpu": ["noasan"], # Test times out.
|
||||
},
|
||||
shard_count = {
|
||||
"cpu": 40,
|
||||
"gpu": 40,
|
||||
"tpu": 40,
|
||||
"cpu": 20,
|
||||
"gpu": 20,
|
||||
"tpu": 20,
|
||||
"iree": 10,
|
||||
},
|
||||
deps = py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
@ -419,6 +417,32 @@ jax_test(
|
||||
},
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "lax_scipy_special_functions_test",
|
||||
srcs = ["lax_scipy_special_functions_test.py"],
|
||||
shard_count = {
|
||||
"cpu": 20,
|
||||
"gpu": 20,
|
||||
"tpu": 20,
|
||||
"iree": 10,
|
||||
},
|
||||
deps = py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "lax_scipy_spectral_dac_test",
|
||||
srcs = ["lax_scipy_spectral_dac_test.py"],
|
||||
shard_count = {
|
||||
"cpu": 40,
|
||||
"gpu": 40,
|
||||
"tpu": 40,
|
||||
"iree": 40,
|
||||
},
|
||||
deps = [
|
||||
"//jax:internal_test_util",
|
||||
] + py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "lax_test",
|
||||
srcs = ["lax_test.py"],
|
||||
|
182
tests/lax_scipy_special_functions_test.py
Normal file
182
tests/lax_scipy_special_functions_test.py
Normal file
@ -0,0 +1,182 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import collections
|
||||
import functools
|
||||
import itertools
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
import scipy.special as osp_special
|
||||
|
||||
import jax
|
||||
from jax._src import test_util as jtu
|
||||
from jax.scipy import special as lsp_special
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
|
||||
all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)]
|
||||
|
||||
OpRecord = collections.namedtuple(
|
||||
"OpRecord",
|
||||
["name", "nargs", "dtypes", "rng_factory", "test_autodiff", "nondiff_argnums", "test_name"])
|
||||
|
||||
|
||||
def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), test_name=None):
|
||||
test_name = test_name or name
|
||||
nondiff_argnums = tuple(sorted(set(nondiff_argnums)))
|
||||
return OpRecord(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums, test_name)
|
||||
|
||||
|
||||
float_dtypes = jtu.dtypes.floating
|
||||
int_dtypes = jtu.dtypes.integer
|
||||
|
||||
# TODO(phawkins): we should probably separate out the function domains used for
|
||||
# autodiff tests from the function domains used for equivalence testing. For
|
||||
# example, logit should closely match its scipy equivalent everywhere, but we
|
||||
# don't expect numerical gradient tests to pass for inputs very close to 0.
|
||||
|
||||
JAX_SPECIAL_FUNCTION_RECORDS = [
|
||||
op_record(
|
||||
"betaln", 2, float_dtypes, jtu.rand_positive, False
|
||||
),
|
||||
op_record(
|
||||
"betainc", 3, float_dtypes, jtu.rand_positive, False
|
||||
),
|
||||
op_record(
|
||||
"digamma", 1, float_dtypes, jtu.rand_positive, True
|
||||
),
|
||||
op_record(
|
||||
"gammainc", 2, float_dtypes, jtu.rand_positive, True
|
||||
),
|
||||
op_record(
|
||||
"gammaincc", 2, float_dtypes, jtu.rand_positive, True
|
||||
),
|
||||
op_record(
|
||||
"erf", 1, float_dtypes, jtu.rand_small_positive, True
|
||||
),
|
||||
op_record(
|
||||
"erfc", 1, float_dtypes, jtu.rand_small_positive, True
|
||||
),
|
||||
op_record(
|
||||
"erfinv", 1, float_dtypes, jtu.rand_small_positive, True
|
||||
),
|
||||
op_record(
|
||||
"expit", 1, float_dtypes, jtu.rand_small_positive, True
|
||||
),
|
||||
# TODO: gammaln has slightly high error.
|
||||
op_record(
|
||||
"gammaln", 1, float_dtypes, jtu.rand_positive, False
|
||||
),
|
||||
op_record(
|
||||
"i0", 1, float_dtypes, jtu.rand_default, True
|
||||
),
|
||||
op_record(
|
||||
"i0e", 1, float_dtypes, jtu.rand_default, True
|
||||
),
|
||||
op_record(
|
||||
"i1", 1, float_dtypes, jtu.rand_default, True
|
||||
),
|
||||
op_record(
|
||||
"i1e", 1, float_dtypes, jtu.rand_default, True
|
||||
),
|
||||
op_record(
|
||||
"logit", 1, float_dtypes,
|
||||
functools.partial(jtu.rand_uniform, low=0.05, high=0.95), True),
|
||||
op_record(
|
||||
"log_ndtr", 1, float_dtypes, jtu.rand_default, True
|
||||
),
|
||||
op_record(
|
||||
"ndtri", 1, float_dtypes,
|
||||
functools.partial(jtu.rand_uniform, low=0.05, high=0.95), True,
|
||||
),
|
||||
op_record(
|
||||
"ndtr", 1, float_dtypes, jtu.rand_default, True
|
||||
),
|
||||
# TODO(phawkins): gradient of entr yields NaNs.
|
||||
op_record(
|
||||
"entr", 1, float_dtypes, jtu.rand_default, False
|
||||
),
|
||||
op_record(
|
||||
"polygamma", 2, (int_dtypes, float_dtypes),
|
||||
jtu.rand_positive, True, (0,)),
|
||||
op_record(
|
||||
"xlogy", 2, float_dtypes, jtu.rand_positive, True
|
||||
),
|
||||
op_record(
|
||||
"xlog1py", 2, float_dtypes, jtu.rand_default, True
|
||||
),
|
||||
# TODO: enable gradient test for zeta by restricting the domain of
|
||||
# of inputs to some reasonable intervals
|
||||
op_record("zeta", 2, float_dtypes, jtu.rand_positive, False),
|
||||
# TODO: float64 produces aborts on gpu, potentially related to use of jnp.piecewise
|
||||
op_record(
|
||||
"expi", 1, [np.float32],
|
||||
functools.partial(jtu.rand_not_small, offset=0.1), True),
|
||||
op_record("exp1", 1, [np.float32], jtu.rand_positive, True),
|
||||
op_record(
|
||||
"expn", 2, (int_dtypes, [np.float32]), jtu.rand_positive, True, (0,)),
|
||||
]
|
||||
|
||||
|
||||
class LaxScipySpcialFunctionsTest(jtu.JaxTestCase):
|
||||
|
||||
def _GetArgsMaker(self, rng, shapes, dtypes):
|
||||
return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
|
||||
|
||||
@parameterized.parameters(itertools.chain.from_iterable(
|
||||
jtu.sample_product_testcases(
|
||||
[dict(op=rec.name, rng_factory=rec.rng_factory,
|
||||
test_autodiff=rec.test_autodiff,
|
||||
nondiff_argnums=rec.nondiff_argnums)],
|
||||
shapes=itertools.combinations_with_replacement(all_shapes, rec.nargs),
|
||||
dtypes=(itertools.combinations_with_replacement(rec.dtypes, rec.nargs)
|
||||
if isinstance(rec.dtypes, list) else itertools.product(*rec.dtypes)),
|
||||
)
|
||||
for rec in JAX_SPECIAL_FUNCTION_RECORDS
|
||||
))
|
||||
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
|
||||
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
|
||||
def testScipySpecialFun(self, op, rng_factory, shapes, dtypes,
|
||||
test_autodiff, nondiff_argnums):
|
||||
scipy_op = getattr(osp_special, op)
|
||||
lax_op = getattr(lsp_special, op)
|
||||
rng = rng_factory(self.rng())
|
||||
args_maker = self._GetArgsMaker(rng, shapes, dtypes)
|
||||
args = args_maker()
|
||||
self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3,
|
||||
check_dtypes=False)
|
||||
self._CompileAndCheck(lax_op, args_maker, rtol=1e-4)
|
||||
|
||||
if test_autodiff:
|
||||
def partial_lax_op(*vals):
|
||||
list_args = list(vals)
|
||||
for i in nondiff_argnums:
|
||||
list_args.insert(i, args[i])
|
||||
return lax_op(*list_args)
|
||||
|
||||
assert list(nondiff_argnums) == sorted(set(nondiff_argnums))
|
||||
diff_args = [x for i, x in enumerate(args) if i not in nondiff_argnums]
|
||||
jtu.check_grads(partial_lax_op, diff_args, order=1,
|
||||
atol=jtu.if_device_under_test("tpu", .1, 1e-3),
|
||||
rtol=.1, eps=1e-3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
64
tests/lax_scipy_spectral_dac_test.py
Normal file
64
tests/lax_scipy_spectral_dac_test.py
Normal file
@ -0,0 +1,64 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lax import eigh as lax_eigh
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
|
||||
linear_sizes = [16, 97, 128]
|
||||
|
||||
|
||||
class LaxScipySpectralDacTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(
|
||||
linear_size=linear_sizes,
|
||||
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
|
||||
termination_size=[1, 19],
|
||||
)
|
||||
def test_spectral_dac_eigh(self, linear_size, dtype, termination_size):
|
||||
if jtu.device_under_test() != "tpu" and termination_size != 1:
|
||||
raise unittest.SkipTest(
|
||||
"Termination sizes greater than 1 only work on TPU")
|
||||
|
||||
rng = self.rng()
|
||||
H = rng.randn(linear_size, linear_size)
|
||||
H = jnp.array(0.5 * (H + H.conj().T)).astype(dtype)
|
||||
if jnp.dtype(dtype).name in ("bfloat16", "float16"):
|
||||
self.assertRaises(
|
||||
NotImplementedError, lax_eigh.eigh, H)
|
||||
return
|
||||
evs, V = lax_eigh.eigh(H, termination_size=termination_size)
|
||||
ev_exp, _ = jnp.linalg.eigh(H)
|
||||
HV = jnp.dot(H, V, precision=lax.Precision.HIGHEST)
|
||||
vV = evs.astype(V.dtype)[None, :] * V
|
||||
eps = jnp.finfo(H.dtype).eps
|
||||
atol = jnp.linalg.norm(H) * eps
|
||||
self.assertAllClose(ev_exp, jnp.sort(evs), atol=20 * atol)
|
||||
self.assertAllClose(
|
||||
HV, vV, atol=atol * (140 if jnp.issubdtype(dtype, jnp.complexfloating)
|
||||
else 40))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
@ -13,14 +13,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import collections
|
||||
import functools
|
||||
from functools import partial
|
||||
import itertools
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
import scipy.special as osp_special
|
||||
@ -34,7 +32,6 @@ from jax.tree_util import tree_map
|
||||
from jax._src import test_util as jtu
|
||||
from jax.scipy import special as lsp_special
|
||||
from jax.scipy import cluster as lsp_cluster
|
||||
from jax._src.lax import eigh as lax_eigh
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
@ -61,8 +58,6 @@ sides = ["right", "left"]
|
||||
methods = ["qdwh", "svd"]
|
||||
seeds = [1, 10]
|
||||
|
||||
linear_sizes = [16, 128, 256]
|
||||
|
||||
|
||||
def _initialize_polar_test(rng, shape, n_zero_svs, degeneracy, geometric_spectrum,
|
||||
max_sv, nonzero_condition_number, dtype):
|
||||
@ -94,67 +89,10 @@ def _initialize_polar_test(rng, shape, n_zero_svs, degeneracy, geometric_spectru
|
||||
spectrum = jnp.array(svs).astype(dtype)
|
||||
return result, spectrum
|
||||
|
||||
OpRecord = collections.namedtuple(
|
||||
"OpRecord",
|
||||
["name", "nargs", "dtypes", "rng_factory", "test_autodiff", "nondiff_argnums", "test_name"])
|
||||
|
||||
|
||||
def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), test_name=None):
|
||||
test_name = test_name or name
|
||||
nondiff_argnums = tuple(sorted(set(nondiff_argnums)))
|
||||
return OpRecord(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums, test_name)
|
||||
|
||||
# TODO(phawkins): we should probably separate out the function domains used for
|
||||
# autodiff tests from the function domains used for equivalence testing. For
|
||||
# example, logit should closely match its scipy equivalent everywhere, but we
|
||||
# don't expect numerical gradient tests to pass for inputs very close to 0.
|
||||
|
||||
JAX_SPECIAL_FUNCTION_RECORDS = [
|
||||
op_record("betaln", 2, float_dtypes, jtu.rand_positive, False),
|
||||
op_record("betainc", 3, float_dtypes, jtu.rand_positive, False),
|
||||
op_record("digamma", 1, float_dtypes, jtu.rand_positive, True),
|
||||
op_record("gammainc", 2, float_dtypes, jtu.rand_positive, True),
|
||||
op_record("gammaincc", 2, float_dtypes, jtu.rand_positive, True),
|
||||
op_record("erf", 1, float_dtypes, jtu.rand_small_positive, True),
|
||||
op_record("erfc", 1, float_dtypes, jtu.rand_small_positive, True),
|
||||
op_record("erfinv", 1, float_dtypes, jtu.rand_small_positive, True),
|
||||
op_record("expit", 1, float_dtypes, jtu.rand_small_positive, True),
|
||||
# TODO: gammaln has slightly high error.
|
||||
op_record("gammaln", 1, float_dtypes, jtu.rand_positive, False),
|
||||
op_record("i0", 1, float_dtypes, jtu.rand_default, True),
|
||||
op_record("i0e", 1, float_dtypes, jtu.rand_default, True),
|
||||
op_record("i1", 1, float_dtypes, jtu.rand_default, True),
|
||||
op_record("i1e", 1, float_dtypes, jtu.rand_default, True),
|
||||
op_record("logit", 1, float_dtypes, partial(jtu.rand_uniform, low=0.05,
|
||||
high=0.95), True),
|
||||
op_record("log_ndtr", 1, float_dtypes, jtu.rand_default, True),
|
||||
op_record("ndtri", 1, float_dtypes, partial(jtu.rand_uniform, low=0.05,
|
||||
high=0.95),
|
||||
True),
|
||||
op_record("ndtr", 1, float_dtypes, jtu.rand_default, True),
|
||||
# TODO(phawkins): gradient of entr yields NaNs.
|
||||
op_record("entr", 1, float_dtypes, jtu.rand_default, False),
|
||||
op_record("polygamma", 2, (int_dtypes, float_dtypes), jtu.rand_positive, True, (0,)),
|
||||
op_record("xlogy", 2, float_dtypes, jtu.rand_positive, True),
|
||||
op_record("xlog1py", 2, float_dtypes, jtu.rand_default, True),
|
||||
# TODO: enable gradient test for zeta by restricting the domain of
|
||||
# of inputs to some reasonable intervals
|
||||
op_record("zeta", 2, float_dtypes, jtu.rand_positive, False),
|
||||
# TODO: float64 produces aborts on gpu, potentially related to use of jnp.piecewise
|
||||
op_record("expi", 1, [np.float32], partial(jtu.rand_not_small, offset=0.1),
|
||||
True),
|
||||
op_record("exp1", 1, [np.float32], jtu.rand_positive, True),
|
||||
op_record("expn", 2, (int_dtypes, [np.float32]), jtu.rand_positive, True, (0,)),
|
||||
]
|
||||
|
||||
|
||||
class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
"""Tests for LAX-backed Scipy implementation."""
|
||||
|
||||
def _GetArgsMaker(self, rng, shapes, dtypes):
|
||||
return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
|
||||
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shapes=shapes, axis=axis, use_b=use_b)
|
||||
for shape_group in compatible_shapes
|
||||
@ -232,43 +170,6 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
result = lsp_special.logsumexp(1.0, b=1.0)
|
||||
self.assertEqual(result, 1.0)
|
||||
|
||||
@parameterized.parameters(itertools.chain.from_iterable(
|
||||
jtu.sample_product_testcases(
|
||||
[dict(op=rec.name, rng_factory=rec.rng_factory,
|
||||
test_autodiff=rec.test_autodiff,
|
||||
nondiff_argnums=rec.nondiff_argnums)],
|
||||
shapes=itertools.combinations_with_replacement(all_shapes, rec.nargs),
|
||||
dtypes=(itertools.combinations_with_replacement(rec.dtypes, rec.nargs)
|
||||
if isinstance(rec.dtypes, list) else itertools.product(*rec.dtypes)),
|
||||
)
|
||||
for rec in JAX_SPECIAL_FUNCTION_RECORDS
|
||||
))
|
||||
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
|
||||
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion
|
||||
def testScipySpecialFun(self, op, rng_factory, shapes, dtypes,
|
||||
test_autodiff, nondiff_argnums):
|
||||
scipy_op = getattr(osp_special, op)
|
||||
lax_op = getattr(lsp_special, op)
|
||||
rng = rng_factory(self.rng())
|
||||
args_maker = self._GetArgsMaker(rng, shapes, dtypes)
|
||||
args = args_maker()
|
||||
self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3,
|
||||
check_dtypes=False)
|
||||
self._CompileAndCheck(lax_op, args_maker, rtol=1e-4)
|
||||
|
||||
if test_autodiff:
|
||||
def partial_lax_op(*vals):
|
||||
list_args = list(vals)
|
||||
for i in nondiff_argnums:
|
||||
list_args.insert(i, args[i])
|
||||
return lax_op(*list_args)
|
||||
|
||||
assert list(nondiff_argnums) == sorted(set(nondiff_argnums))
|
||||
diff_args = [x for i, x in enumerate(args) if i not in nondiff_argnums]
|
||||
jtu.check_grads(partial_lax_op, diff_args, order=1,
|
||||
atol=jtu.if_device_under_test("tpu", .1, 1e-3),
|
||||
rtol=.1, eps=1e-3)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=all_shapes,
|
||||
dtype=float_dtypes,
|
||||
@ -581,34 +482,6 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
self.assertAllClose(
|
||||
matrix, recon, atol=tol * jnp.linalg.norm(matrix))
|
||||
|
||||
@jtu.sample_product(
|
||||
linear_size=linear_sizes,
|
||||
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
|
||||
termination_size=[1, 19],
|
||||
)
|
||||
def test_spectral_dac_eigh(self, linear_size, dtype, termination_size):
|
||||
if jtu.device_under_test() != "tpu" and termination_size != 1:
|
||||
raise unittest.SkipTest(
|
||||
"Termination sizes greater than 1 only work on TPU")
|
||||
|
||||
rng = self.rng()
|
||||
H = rng.randn(linear_size, linear_size)
|
||||
H = jnp.array(0.5 * (H + H.conj().T)).astype(dtype)
|
||||
if jnp.dtype(dtype).name in ("bfloat16", "float16"):
|
||||
self.assertRaises(
|
||||
NotImplementedError, lax_eigh.eigh, H)
|
||||
return
|
||||
evs, V = lax_eigh.eigh(H, termination_size=termination_size)
|
||||
ev_exp, eV_exp = jnp.linalg.eigh(H)
|
||||
HV = jnp.dot(H, V, precision=lax.Precision.HIGHEST)
|
||||
vV = evs.astype(V.dtype)[None, :] * V
|
||||
eps = jnp.finfo(H.dtype).eps
|
||||
atol = jnp.linalg.norm(H) * eps
|
||||
self.assertAllClose(ev_exp, jnp.sort(evs), atol=20 * atol)
|
||||
self.assertAllClose(
|
||||
HV, vV, atol=atol * (140 if jnp.issubdtype(dtype, jnp.complexfloating)
|
||||
else 30))
|
||||
|
||||
@jtu.sample_product(
|
||||
n_obs=[1, 3, 5],
|
||||
n_codes=[1, 2, 4],
|
||||
@ -645,6 +518,5 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(actual, nan_array, check_dtypes=False)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user