mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Test: fix casting warning in betainc test
This commit is contained in:
parent
c4d2fc7364
commit
8b74b93501
@ -278,8 +278,11 @@ def lax_ops():
|
||||
"betainc",
|
||||
3,
|
||||
float_dtypes,
|
||||
test_util.rand_positive,
|
||||
{np.float64: 1e-14},
|
||||
test_util.rand_uniform,
|
||||
{
|
||||
np.float32: 1e-5,
|
||||
np.float64: 1e-12,
|
||||
},
|
||||
),
|
||||
op_record(
|
||||
"igamma",
|
||||
|
@ -105,7 +105,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
tol = tol or jtu.default_tolerance()
|
||||
if jtu.test_device_matches(["tpu"]):
|
||||
if dtype in (np.float32, np.complex64) and op_name in (
|
||||
"acosh", "asinh", "cos", "cosh", "digamma", "exp", "exp2", "igamma",
|
||||
"acosh", "asinh", "betainc", "cos", "cosh", "digamma", "exp", "exp2", "igamma",
|
||||
"igammac", "log", "log1p", "logistic", "pow", "sin", "sinh", "tan"):
|
||||
tol = jtu.join_tolerance(tol, 2e-4)
|
||||
elif op_name == "asinh" and dtype == np.float16:
|
||||
|
Loading…
x
Reference in New Issue
Block a user