1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 21:36:05 +00:00

Merge pull request from dfm:fix-ffi-test-segfault

PiperOrigin-RevId: 691062859
This commit is contained in:
jax authors 2024-10-29 10:05:47 -07:00
commit c67cf51f15

@ -13,7 +13,6 @@
# limitations under the License.
import os
import sys
import unittest
from functools import partial
@ -35,6 +34,7 @@ from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.interpreters import mlir
from jax._src.layout import DeviceLocalLayout
from jax._src.lib import lapack
from jax._src.lib.mlir.dialects import hlo
from jax._src.lax import linalg as lax_linalg_internal
@ -261,10 +261,6 @@ class FfiTest(jtu.JaxTestCase):
@jtu.run_on_devices("gpu", "cpu")
def testVectorizedDeprecation(self):
if sys.version_info.major == 3 and sys.version_info.minor == 13:
# TODO(b/376025274): Remove the skip once the bug is fixed.
raise unittest.SkipTest("Crashes on Python 3.13")
x = self.rng().randn(3, 5, 4).astype(np.float32)
with self.assertWarns(DeprecationWarning):
ffi_call_geqrf(x, vectorized=True)
@ -332,6 +328,9 @@ class FfiTest(jtu.JaxTestCase):
def ffi_call_geqrf(x, **kwargs):
if jtu.test_device_matches(["cpu"]):
lapack._lapack.initialize()
assert x.dtype == np.float32
ndim = x.ndim
x_major_to_minor = tuple(range(ndim - 2)) + (ndim - 1, ndim - 2)