mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 21:36:05 +00:00
Merge pull request #24580 from dfm:fix-ffi-test-segfault
PiperOrigin-RevId: 691062859
This commit is contained in:
commit
c67cf51f15
@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
from functools import partial
|
||||
|
||||
@ -35,6 +34,7 @@ from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.layout import DeviceLocalLayout
|
||||
from jax._src.lib import lapack
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.lax import linalg as lax_linalg_internal
|
||||
|
||||
@ -261,10 +261,6 @@ class FfiTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.run_on_devices("gpu", "cpu")
|
||||
def testVectorizedDeprecation(self):
|
||||
if sys.version_info.major == 3 and sys.version_info.minor == 13:
|
||||
# TODO(b/376025274): Remove the skip once the bug is fixed.
|
||||
raise unittest.SkipTest("Crashes on Python 3.13")
|
||||
|
||||
x = self.rng().randn(3, 5, 4).astype(np.float32)
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
ffi_call_geqrf(x, vectorized=True)
|
||||
@ -332,6 +328,9 @@ class FfiTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
def ffi_call_geqrf(x, **kwargs):
|
||||
if jtu.test_device_matches(["cpu"]):
|
||||
lapack._lapack.initialize()
|
||||
|
||||
assert x.dtype == np.float32
|
||||
ndim = x.ndim
|
||||
x_major_to_minor = tuple(range(ndim - 2)) + (ndim - 1, ndim - 2)
|
||||
|
Loading…
x
Reference in New Issue
Block a user