mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Remove obsolete jaxlib version checks
This commit is contained in:
parent
640488883e
commit
b9c7b9bb4f
@ -16,9 +16,8 @@ import hashlib
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
import zlib
|
||||
|
||||
import numpy as np
|
||||
@ -38,14 +37,7 @@ from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib import version_str as jaxlib_version_str
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir import passmanager as pm
|
||||
|
||||
# TODO(phawkins): remove the conditional import after jaxlib 0.4.9 is the
|
||||
# minimum.
|
||||
mlir_jax: Any
|
||||
try:
|
||||
from jax._src.lib.mlir import jax as mlir_jax
|
||||
except ImportError:
|
||||
mlir_jax = None
|
||||
from jax._src.lib.mlir import jax as mlir_jax
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -175,18 +167,15 @@ def _serialize_ir(m: ir.Module) -> bytes:
|
||||
return output.getvalue()
|
||||
|
||||
def _canonicalize_ir(m_original: ir.Module) -> bytes:
|
||||
# TODO(phawkins): remove the 'else' branch when jaxlib 0.4.9 is the minimum.
|
||||
if mlir_jax is not None:
|
||||
with m_original.context:
|
||||
m = m_original.operation.clone()
|
||||
passes = pm.PassManager.parse(
|
||||
"builtin.module(func.func(jax-strip-locations))"
|
||||
)
|
||||
passes.run(m.operation)
|
||||
return _serialize_ir(m)
|
||||
else:
|
||||
bytecode = _serialize_ir(m_original)
|
||||
return re.sub(b" at 0x[a-f0-9]+>", b" at 0x...>", bytecode)
|
||||
# mlir_jax import is required to register jax-strip-locations
|
||||
assert mlir_jax is not None
|
||||
with m_original.context:
|
||||
m = m_original.operation.clone()
|
||||
passes = pm.PassManager.parse(
|
||||
"builtin.module(func.func(jax-strip-locations))"
|
||||
)
|
||||
passes.run(m.operation)
|
||||
return _serialize_ir(m)
|
||||
|
||||
def _hash_computation(hash_obj, module):
|
||||
if config.jax_compilation_cache_include_metadata_in_key:
|
||||
|
@ -83,12 +83,7 @@ version = check_jaxlib_version(
|
||||
import jaxlib.cpu_feature_guard as cpu_feature_guard
|
||||
cpu_feature_guard.check_cpu_features()
|
||||
|
||||
# TODO(phawkins): remove after minimium jaxlib version is 0.4.9 or newer.
|
||||
try:
|
||||
import jaxlib.utils as utils
|
||||
except ImportError:
|
||||
utils = None
|
||||
|
||||
import jaxlib.utils as utils
|
||||
import jaxlib.xla_client as xla_client
|
||||
import jaxlib.lapack as lapack
|
||||
|
||||
|
@ -48,7 +48,6 @@ from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax import utils as lax_utils
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib import gpu_prng
|
||||
from jax._src.lib import version as jaxlib_version
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.numpy.array_methods import (
|
||||
@ -1078,18 +1077,10 @@ def _threefry2x32_gpu_lowering(lowering_func, ctx, k1, k2, x1, x2):
|
||||
length = int(out_len) # will be passed statically
|
||||
output_shape = None
|
||||
|
||||
if (jaxlib_version >= (0, 4, 9)):
|
||||
return lowering_func(
|
||||
(_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
|
||||
(_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length,
|
||||
output_shape)
|
||||
else:
|
||||
if output_shape is not None:
|
||||
raise ValueError("native lowering with shape polymorphism "
|
||||
"for threefry on GPU requires jaxlib version 0.4.9")
|
||||
return lowering_func(
|
||||
(_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
|
||||
(_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length)
|
||||
return lowering_func(
|
||||
(_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
|
||||
(_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length,
|
||||
output_shape)
|
||||
|
||||
threefry2x32_p = core.Primitive("threefry2x32")
|
||||
threefry2x32_p.multiple_results = True
|
||||
|
@ -58,17 +58,7 @@ if TYPE_CHECKING:
|
||||
return list(zip(*args))
|
||||
|
||||
else:
|
||||
# TODO(phawkins): remove the hasattr condition after jaxlib 0.4.9 is the
|
||||
# minimum
|
||||
if hasattr(jaxlib_utils, 'safe_zip'):
|
||||
safe_zip = jaxlib_utils.safe_zip
|
||||
else:
|
||||
def safe_zip(*args):
|
||||
args = list(map(list, args))
|
||||
n = len(args[0])
|
||||
for arg in args[1:]:
|
||||
assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
|
||||
return list(zip(*args))
|
||||
safe_zip = jaxlib_utils.safe_zip
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -95,17 +85,7 @@ if TYPE_CHECKING:
|
||||
return list(map(f, *args))
|
||||
|
||||
else:
|
||||
# TODO(phawkins): remove the hasattr condition after jaxlib 0.4.9 is the
|
||||
# minimum
|
||||
if hasattr(jaxlib_utils, 'safe_map'):
|
||||
safe_map = jaxlib_utils.safe_map
|
||||
else:
|
||||
def safe_map(f, *args):
|
||||
args = list(map(list, args))
|
||||
n = len(args[0])
|
||||
for arg in args[1:]:
|
||||
assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
|
||||
return list(map(f, *args))
|
||||
safe_map = jaxlib_utils.safe_map
|
||||
|
||||
def unzip2(xys: Iterable[tuple[T1, T2]]
|
||||
) -> tuple[tuple[T1, ...], tuple[T2, ...]]:
|
||||
|
@ -199,8 +199,6 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
cc.get_cache_key(computation2, devices, compile_options, backend),
|
||||
)
|
||||
|
||||
@unittest.skipIf(jax._src.lib.version < (0, 4, 9),
|
||||
"Test requires jaxlib 0.4.9")
|
||||
@parameterized.parameters([False, True])
|
||||
def test_identical_computations_different_metadata(self, include_metadata):
|
||||
f = lambda x, y: lax.mul(lax.add(x, y), 2)
|
||||
|
@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import operator
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
@ -122,8 +121,6 @@ class SafeMapTest(jtu.JaxTestCase):
|
||||
util.safe_map(make_tuple, range(4), range(4, 8)),
|
||||
)
|
||||
|
||||
@unittest.skipIf(not hasattr(jaxlib_utils, 'safe_map'),
|
||||
"requires jaxlib 0.4.9")
|
||||
def test_safe_map_errors(self):
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "safe_map requires at least 2 arguments"
|
||||
@ -174,8 +171,6 @@ class SafeZipTest(jtu.JaxTestCase):
|
||||
util.safe_zip(range(4), range(4, 8)),
|
||||
)
|
||||
|
||||
@unittest.skipIf(not hasattr(jaxlib_utils, 'safe_zip'),
|
||||
"requires jaxlib 0.4.9")
|
||||
def test_safe_zip_errors(self):
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "safe_zip requires at least 1 argument"
|
||||
|
Loading…
x
Reference in New Issue
Block a user