Mirror of https://github.com/ROCm/jax.git, synced 2025-04-14 10:56:06 +00:00
Remove code that supported jaxlib < 0.5.
The new xla_extension_version is 303 and the new mlir_api_version is 57.
Parent: 7d81547f91
Commit: efab6945ca
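Note: the hunks below delete the jaxlib/XLA version gates themselves. As a hedged, illustrative sketch only (the imports and threshold values come from the diff; choose_code_path and its return values are hypothetical), the pattern being removed looks like this:

    from jax._src.lib import version as jaxlib_version   # now always >= (0, 5, 0)
    from jax._src.lib import xla_extension_version       # now always >= 303

    def choose_code_path():
      # Old pattern: branch on the installed jaxlib / XLA extension version so
      # newer call signatures are only used when the runtime supports them.
      if jaxlib_version >= (0, 4, 37) and xla_extension_version > 302:
        return "new code path"
      return "legacy fallback for jaxlib < 0.5"

With jaxlib 0.5 as the minimum, the guard is always true, so the commit keeps only the first branch and drops the fallback along with the now-unused version imports.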
@@ -32,8 +32,6 @@ from jax._src import path as pathlib
 from jax._src import profiler
 from jax._src import traceback_util
 from jax._src.interpreters import mlir
-from jax._src.lib import version as jaxlib_version
-from jax._src.lib import xla_extension_version  # pylint: disable=g-importing-member
 from jax._src.lib import xla_client as xc
 from jax._src.lib.mlir import ir
 import numpy as np
@@ -199,10 +197,7 @@ def get_compile_options(
     logger.debug("Explicitly disabling command buffer scheduling for AutoPGLE.")
     if env_options_overrides is None:
       env_options_overrides = {}
-    if xla_extension_version > 302:
-      env_options_overrides['xla_gpu_enable_command_buffer'] = ''
-    else:
-      env_options_overrides['xla_gpu_graph_min_graph_size'] = '100000'
+    env_options_overrides['xla_gpu_enable_command_buffer'] = ''

   if env_options_overrides is not None:
     # Some overrides are passed directly on build_options.
@@ -258,7 +253,7 @@ get_compile_options(
   debug_options.xla_detailed_logging = detailed_logging

   # If persistent cache is enabled, also enable additional XLA caching features.
-  if compilation_cache.is_persistent_cache_enabled() and jaxlib_version > (0, 4, 35):
+  if compilation_cache.is_persistent_cache_enabled():
     # compilation_cache_dir can't be None here, but the type checker is a bit
     # strict.
     path = pathlib.Path(config.compilation_cache_dir.value or "")
@@ -47,7 +47,6 @@ from jax._src.lax.lax import (
 from jax._src.lib import gpu_solver
 from jax._src.lib import gpu_sparse
 from jax._src.lib import lapack
-from jax._src.lib import version as jaxlib_version
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import chlo
 from jax._src.lib.mlir.dialects import hlo
@@ -1364,10 +1363,8 @@ def _triangular_solve_cpu_lower(
   if len(a_aval.shape) == 2 and np.dtype(a_aval.dtype) in _cpu_lapack_types:
     alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype))
     b_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, b_aval.shape)
-    # TODO(b/344892332): Remove the conditional after the compatibility period.
-    ctx_args = (ctx,) if jaxlib_version >= (0, 4, 37) else ()
     return lapack.trsm_hlo(
-        *ctx_args, a_aval.dtype, alpha,
+        ctx, a_aval.dtype, alpha,
         a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal,
         b_shape_vals=b_shape_vals)
   else:
@@ -2540,9 +2537,6 @@ def _tridiagonal_solve_gpu_lowering(lowering, ctx, dl, d, du, b):


 def _tridiagonal_solve_cpu_lowering(ctx, dl, d, du, b, **kwargs):
-  if jaxlib_version <= (0, 4, 38):
-    rule = mlir.lower_fun(_tridiagonal_solve_jax, multiple_results=False)
-    return rule(ctx, dl, d, du, b, **kwargs)
   b_aval = ctx.avals_in[-1]
   batch_dims = b_aval.shape[:-2]
   target_name = lapack.prepare_lapack_call("gtsv_ffi", b_aval.dtype)
@@ -2755,13 +2749,12 @@ def _schur_cpu_lowering(ctx, operand, *, compute_schur_vectors, sort_eig_vals,

   a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
-  # TODO(b/344892332): Remove the conditional after the compatibility period.
-  ctx_args = (ctx,) if jaxlib_version >= (0, 4, 37) else ()
-  gees_result = lapack.gees_hlo(*ctx_args, operand_aval.dtype, operand,
+  gees_result = lapack.gees_hlo(ctx, operand_aval.dtype, operand,
                                 jobvs=compute_schur_vectors,
                                 sort=sort_eig_vals,
                                 select=select_callable,
                                 a_shape_vals=a_shape_vals)
-  if jaxlib_version >= (0, 4, 37) and not ctx.is_forward_compat():
+  if not ctx.is_forward_compat():
     schur_form, schur_vectors, _eig_vals, _selected_eig_vals, info = gees_result
   else:
     # Number of return values depends on value of sort_eig_vals.
@@ -42,7 +42,6 @@ from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
 from jax._src.lax import lax as lax_internal
 from jax._src.lax.control_flow import for_loop
-from jax._src.lib import version as jaxlib_version
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import arith
 from jax._src.lib.mlir.dialects import func
@@ -1550,10 +1549,6 @@ def _masked_swap_lowering_rule(
       if mask is not None:
         raise NotImplementedError("masked swap with strided store")
       tpu.StridedStoreOp(val, ref, starts, strides)
-    elif jaxlib_version <= (0, 4, 35):
-      if mask is not None:
-        raise NotImplementedError("masked swap with vector store")
-      vector.StoreOp(val, ref, starts)
     else:
       tpu.VectorStoreOp(val, ref, starts, [], mask=mask)
     return result
@@ -277,21 +277,13 @@ def make_cpu_client(
         f"{CPU_COLLECTIVES_IMPLEMENTATIONS}.")

   num_devices = NUM_CPU_DEVICES.value if NUM_CPU_DEVICES.value >= 0 else None
-  if xla_client._version < 303 and num_devices is not None:
-    xla_flags = os.getenv("XLA_FLAGS") or ""
-    os.environ["XLA_FLAGS"] = (
-        f"{xla_flags} --xla_force_host_platform_device_count={num_devices}"
-    )
-    num_devices = None
-  # TODO(phawkins): pass num_devices directly when version 303 is the minimum.
-  kwargs = {} if num_devices is None else {"num_devices": num_devices}
   return xla_client.make_cpu_client(
       asynchronous=_CPU_ENABLE_ASYNC_DISPATCH.value,
       distributed_client=distributed.global_state.client,
       node_id=distributed.global_state.process_id,
       num_nodes=distributed.global_state.num_processes,
       collectives=collectives,
-      **kwargs,
+      num_devices=num_devices,
   )


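Side note on the make_cpu_client hunk above: a minimal sketch (an assumption-laden illustration, not part of the diff) of the call it simplifies. With xla_extension_version >= 303 the host device count is passed straight to xla_client.make_cpu_client rather than injected through XLA_FLAGS; the count of 8 is an arbitrary example value.

    from jax._src.lib import xla_client

    # Previously this required XLA_FLAGS="--xla_force_host_platform_device_count=8"
    # to be set in the environment before the client was created.
    cpu_client = xla_client.make_cpu_client(num_devices=8)
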
@@ -31,7 +31,6 @@ from jax._src import config
 from jax._src import test_util as jtu
 from jax._src import xla_bridge
 from jax._src.lib import xla_client
-from jax._src.lib import version as jaxlib_version
 from jax._src.lib.mlir import ir
 from jax._src.mesh import Mesh
 from jax._src.partition_spec import PartitionSpec as P
@@ -69,8 +68,7 @@ class CacheKeyTest(jtu.JaxTestCase):
     debug_options.xla_dump_hlo_as_long_text = True
     debug_options.xla_dump_disable_metadata = True
     debug_options.xla_dump_hlo_pipeline_re = "xyzzy"
-    if jaxlib_version > (0, 4, 35):
-      debug_options.xla_gpu_experimental_autotune_cache_mode = 2
+    debug_options.xla_gpu_experimental_autotune_cache_mode = 2
     hash2 = self.get_hashed_value(
         cache_key._hash_serialized_compile_options, compile_options
     )
@@ -531,8 +531,6 @@ class CompilationCacheTest(CompilationCacheTestCase):
         executable.fingerprint, deserialized_executable.fingerprint)

   def test_persistent_cache_enable_xla_caches(self):
-    if jtu.jaxlib_version() <= (0, 4, 35):
-      self.skipTest("Test requires AutotuneCacheMode bindings")
     s = os.sep
     with config.compilation_cache_dir("jax-cache"):
       with config.persistent_cache_enable_xla_caches("none"):
@@ -603,8 +601,6 @@ class CompilationCacheDisabledTest(CompilationCacheTestCase):
     self.assertEqual(count_after_second_use, count_after_first_use)

   def test_persistent_cache_enable_xla_caches_disabled(self):
-    if jtu.jaxlib_version() <= (0, 4, 35):
-      self.skipTest("Test requires AutotuneCacheMode bindings")
     with config.enable_compilation_cache(False):
       compile_options = compiler.get_compile_options(
           num_replicas=1, num_partitions=1
@@ -71,7 +71,6 @@ from jax.sharding import PartitionSpec as P
 from jax._src import config
 from jax._src import test_util as jtu
 from jax._src.lib import cuda_versions
-from jax._src.lib import version as jaxlib_version

 config.parse_flags_with_absl()

@@ -650,14 +649,11 @@ class CompatTest(bctu.CompatTestBase):

     self.run_one_test(func, data, rtol=rtol, atol=atol,
                       check_results=check_schur_results)
-    # TODO(b/344892332): Remove the check after the compatibility period.
-    has_xla_ffi_support = jaxlib_version >= (0, 4, 37)
-    if has_xla_ffi_support:
-      with config.export_ignore_forward_compatibility(True):
-        # FFI Kernel test
-        data = self.load_testdata(cpu_schur_lapack_gees.data_2024_11_29[dtype_name])
-        self.run_one_test(func, data, rtol=rtol, atol=atol,
-                          check_results=check_schur_results)
+    with config.export_ignore_forward_compatibility(True):
+      # FFI Kernel test
+      data = self.load_testdata(cpu_schur_lapack_gees.data_2024_11_29[dtype_name])
+      self.run_one_test(func, data, rtol=rtol, atol=atol,
+                        check_results=check_schur_results)

   @parameterized.named_parameters(
       dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
@@ -746,14 +742,11 @@ class CompatTest(bctu.CompatTestBase):

     self.run_one_test(func, data, rtol=rtol, atol=atol,
                       check_results=check_triangular_solve_results)
-    # TODO(b/344892332): Remove the check after the compatibility period.
-    has_xla_ffi_support = jaxlib_version >= (0, 4, 37)
-    if has_xla_ffi_support:
-      with config.export_ignore_forward_compatibility(True):
-        # FFI Kernel test
-        data = self.load_testdata(cpu_triangular_solve_blas_trsm.data_2024_12_02[dtype_name])
-        self.run_one_test(func, data, rtol=rtol, atol=atol,
-                          check_results=check_triangular_solve_results)
+    with config.export_ignore_forward_compatibility(True):
+      # FFI Kernel test
+      data = self.load_testdata(cpu_triangular_solve_blas_trsm.data_2024_12_02[dtype_name])
+      self.run_one_test(func, data, rtol=rtol, atol=atol,
+                        check_results=check_triangular_solve_results)

   @parameterized.named_parameters(
       dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
@@ -808,23 +801,18 @@ class CompatTest(bctu.CompatTestBase):
         cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_09_03[dtype_name]
     )
     self.run_one_test(func, data, rtol=rtol, atol=atol)
-    # TODO(b/344892332): Remove the check after the compatibility period.
-    has_xla_ffi_support = jaxlib_version >= (0, 4, 37)
-    if has_xla_ffi_support:
-      with config.export_ignore_forward_compatibility(True):
-        # FFI Kernel test
-        data = self.load_testdata(
-            cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_12_01[dtype_name]
-        )
-        self.run_one_test(func, data, rtol=rtol, atol=atol)
+    with config.export_ignore_forward_compatibility(True):
+      # FFI Kernel test
+      data = self.load_testdata(
+          cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_12_01[dtype_name]
+      )
+      self.run_one_test(func, data, rtol=rtol, atol=atol)

   @parameterized.named_parameters(
       dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
       for dtype_name in ("f32", "f64", "c64", "c128"))
   @jax.default_matmul_precision("float32")
   def test_cpu_tridiagonal_solve_lapack_gtsv(self, dtype_name):
-    if jtu.jaxlib_version() <= (0, 4, 38):
-      self.skipTest("Test requires a newer jaxlib version")
     if not config.enable_x64.value and dtype_name in ["f64", "c128"]:
       self.skipTest("Test disabled for x32 mode")

@@ -50,7 +50,6 @@ from jax._src import core
 from jax._src import dtypes
 from jax._src import test_util as jtu
 from jax._src.lax import lax as lax_internal
-from jax._src.lib import version as jaxlib_version
 from jax._src.util import safe_zip, NumpyComplexWarning, tuple_replace

 config.parse_flags_with_absl()
@@ -1566,8 +1565,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
       self.skipTest(f"{dtype} gets promoted to {np.float16}, which is not supported.")
     elif rank == 2 and not jtu.test_device_matches(["cpu", "gpu"]):
       self.skipTest("Nonsymmetric eigendecomposition is only implemented on the CPU and GPU backends.")
-    if rank == 2 and jaxlib_version <= (0, 4, 35) and jtu.test_device_matches(["gpu"]):
-      self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
     rng = jtu.rand_default(self.rng())
     tol = { np.int8: 2e-3, np.int32: 1e-3, np.float32: 1e-3, np.float64: 1e-6 }
     if jtu.test_device_matches(["tpu"]):
@@ -268,8 +268,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
   @jtu.run_on_devices("cpu", "gpu")
   def testEig(self, shape, dtype, compute_left_eigenvectors,
               compute_right_eigenvectors):
-    if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35):
-      self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
     rng = jtu.rand_default(self.rng())
     n = shape[-1]
     args_maker = lambda: [rng(shape, dtype)]
@@ -312,8 +310,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
   def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors,
                               compute_right_eigenvectors):
     """Verifies that `eig` fails gracefully if given non-finite inputs."""
-    if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35):
-      self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
     a = jnp.full(shape, jnp.nan, dtype)
     results = lax.linalg.eig(
         a, compute_left_eigenvectors=compute_left_eigenvectors,
@@ -331,8 +327,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
     # haven't checked, that might be because of perturbations causing the
     # ordering of eigenvalues to change, which will trip up check_grads. So we
     # just test on small-ish matrices.
-    if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35):
-      self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
     rng = jtu.rand_default(self.rng())
     args_maker = lambda: [rng(shape, dtype)]
     a, = args_maker()
@@ -346,8 +340,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
   )
   @jtu.run_on_devices("cpu", "gpu")
   def testEigvals(self, shape, dtype):
-    if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35):
-      self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
     rng = jtu.rand_default(self.rng())
     args_maker = lambda: [rng(shape, dtype)]
     a, = args_maker()
@@ -358,8 +350,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
   @jtu.run_on_devices("cpu", "gpu")
   def testEigvalsInf(self):
     # https://github.com/jax-ml/jax/issues/2661
-    if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35):
-      self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
     x = jnp.array([[jnp.inf]])
     self.assertTrue(jnp.all(jnp.isnan(jnp.linalg.eigvals(x))))

@@ -369,8 +359,6 @@ class NumpyLinalgTest(jtu.JaxTestCase):
   )
   @jtu.run_on_devices("cpu", "gpu")
   def testEigBatching(self, shape, dtype):
-    if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35):
-      self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
     rng = jtu.rand_default(self.rng())
     shape = (10,) + shape
     args = rng(shape, dtype)
@@ -1697,8 +1685,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
   )
   def testScipyQrModes(self, shape, dtype, mode, pivoting):
     is_not_cpu_test_device = not jtu.test_device_matches(["cpu"])
-    is_not_valid_jaxlib_version = jtu.jaxlib_version() <= (0, 4, 38)
-    if pivoting and (is_not_cpu_test_device or is_not_valid_jaxlib_version):
+    if pivoting and is_not_cpu_test_device:
       self.skipTest("Pivoting is only supported on CPU with jaxlib > 0.4.38")
     rng = jtu.rand_default(self.rng())
     jsp_func = partial(jax.scipy.linalg.qr, mode=mode, pivoting=pivoting)
@@ -42,8 +42,6 @@ class MagmaLinalgTest(jtu.JaxTestCase):
   @jtu.run_on_devices("gpu")
   def testEig(self, shape, dtype, compute_left_eigenvectors,
               compute_right_eigenvectors):
-    if jtu.jaxlib_version() <= (0, 4, 35):
-      self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
     if not gpu_solver.has_magma():
       self.skipTest("MAGMA is not installed or can't be loaded.")
     # TODO(b/377907938), TODO(danfm): Debug issues MAGMA support for
@@ -93,8 +91,6 @@ class MagmaLinalgTest(jtu.JaxTestCase):
   def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors,
                               compute_right_eigenvectors):
     """Verifies that `eig` fails gracefully if given non-finite inputs."""
-    if jtu.jaxlib_version() <= (0, 4, 35):
-      self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
     if not gpu_solver.has_magma():
       self.skipTest("MAGMA is not installed or can't be loaded.")
     # TODO(b/377907938), TODO(danfm): Debug issues MAGMA support for
@@ -110,8 +106,6 @@ class MagmaLinalgTest(jtu.JaxTestCase):
     self.assertTrue(np.all(np.isnan(result)))

   def testEigMagmaConfig(self):
-    if jtu.jaxlib_version() <= (0, 4, 35):
-      self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
     if not gpu_solver.has_magma():
       self.skipTest("MAGMA is not installed or can't be loaded.")
     rng = jtu.rand_default(self.rng())
@@ -272,8 +272,6 @@ class OpsTest(PallasBaseTest):

   @parameterized.product(dtype=[jnp.float32, jnp.bfloat16, jnp.int8])
   def test_cast_vector_to_mask(self, dtype):
-    if jtu.jaxlib_version() <= (0, 4, 39):
-      self.skipTest("Test requires non-32-bit selection support")
     shape = (128, 128)
     bitwidth = pallas_utils.dtype_bitwidth(dtype)
     if (
@@ -1724,8 +1724,6 @@ class PallasCallTest(PallasBaseTest):
     np.testing.assert_array_equal(y, jnp.concatenate([x, x], axis=1))

   def test_masked_store(self):
-    if jtu.jaxlib_version() <= (0, 4, 35):
-      self.skipTest("Test requires masked store support")
     shape = (16, 256)
     mask_shape = (10, 130)
     mask_start = (4, 5)
@@ -37,7 +37,6 @@ from jax.experimental.serialize_executable import (
 import jax.numpy as jnp
 from jax.sharding import NamedSharding, PartitionSpec
 import numpy as np
-from jax._src.lib import xla_extension_version  # pylint: disable=g-importing-member

 jax.config.parse_flags_with_absl()

@@ -94,10 +93,7 @@ class PgleTest(jtu.JaxTestCase):
         'xla_gpu_enable_latency_hiding_scheduler': 'True',
     }
     # TODO(b/376647494): Remove this flag once the bug is fixed.
-    if xla_extension_version > 302:
-      compiler_options['xla_gpu_enable_command_buffer'] = ''
-    else:
-      compiler_options['xla_gpu_graph_min_graph_size'] = '100000'
+    compiler_options['xla_gpu_enable_command_buffer'] = ''
     @partial(
         jax.jit,
         in_shardings=NamedSharding(mesh, PartitionSpec('x')),
@@ -138,8 +134,6 @@ class PgleTest(jtu.JaxTestCase):
         'xla_gpu_experimental_dump_fdo_profiles': 'True',
     }
     # TODO(b/376647494): Remove this flag once the bug is fixed.
-    if xla_extension_version <= 302:
-      compile_options['xla_gpu_graph_min_graph_size'] = '100000'
     @partial(
         jax.jit,
         in_shardings=NamedSharding(mesh, PartitionSpec('x')),
@@ -224,8 +218,6 @@ class PgleTest(jtu.JaxTestCase):
         'xla_gpu_experimental_dump_fdo_profiles': 'True',
     }
     # TODO(b/376647494): Remove this flag once the bug is fixed.
-    if xla_extension_version <= 302:
-      compiler_options['xla_gpu_graph_min_graph_size'] = '100000'
     @partial(
         jax.jit,
         in_shardings=NamedSharding(mesh, PartitionSpec('x')),
@@ -48,7 +48,6 @@ import jax.numpy as jnp
 from jax.experimental.custom_partitioning import custom_partitioning
 from jax.experimental.shard_map import shard_map

-from jax._src.lib import xla_extension_version  # pylint: disable=g-importing-member

 config.parse_flags_with_absl()
 jtu.request_cpu_devices(8)
@@ -2170,8 +2169,6 @@ class ShardMapTest(jtu.JaxTestCase):
     self.assertAllClose(f(), np.arange(1, dtype=np.int32).reshape(-1, 1))

   def test_partial_auto_ppermute(self):
-    if xla_extension_version < 302:
-      self.skipTest('minimum xla extension version 302')
     if config.use_shardy_partitioner.value:
       self.skipTest('Shardy does not support full-to-shard.')
