Remove obsolete jaxlib version checks

This commit is contained in:
Jake VanderPlas 2023-07-12 11:53:55 -07:00
parent 640488883e
commit b9c7b9bb4f
6 changed files with 18 additions and 70 deletions

View File

@ -16,9 +16,8 @@ import hashlib
import io
import logging
import os
import re
import sys
from typing import Any, Optional
from typing import Optional
import zlib
import numpy as np
@ -38,14 +37,7 @@ from jax._src.lib import xla_extension_version
from jax._src.lib import version_str as jaxlib_version_str
from jax._src.lib.mlir import ir
from jax._src.lib.mlir import passmanager as pm
# TODO(phawkins): remove the conditional import after jaxlib 0.4.9 is the
# minimum.
mlir_jax: Any
try:
from jax._src.lib.mlir import jax as mlir_jax
except ImportError:
mlir_jax = None
from jax._src.lib.mlir import jax as mlir_jax
logger = logging.getLogger(__name__)
@ -175,18 +167,15 @@ def _serialize_ir(m: ir.Module) -> bytes:
return output.getvalue()
def _canonicalize_ir(m_original: ir.Module) -> bytes:
# TODO(phawkins): remove the 'else' branch when jaxlib 0.4.9 is the minimum.
if mlir_jax is not None:
with m_original.context:
m = m_original.operation.clone()
passes = pm.PassManager.parse(
"builtin.module(func.func(jax-strip-locations))"
)
passes.run(m.operation)
return _serialize_ir(m)
else:
bytecode = _serialize_ir(m_original)
return re.sub(b" at 0x[a-f0-9]+>", b" at 0x...>", bytecode)
# mlir_jax import is required to register jax-strip-locations
assert mlir_jax is not None
with m_original.context:
m = m_original.operation.clone()
passes = pm.PassManager.parse(
"builtin.module(func.func(jax-strip-locations))"
)
passes.run(m.operation)
return _serialize_ir(m)
def _hash_computation(hash_obj, module):
if config.jax_compilation_cache_include_metadata_in_key:

View File

@ -83,12 +83,7 @@ version = check_jaxlib_version(
import jaxlib.cpu_feature_guard as cpu_feature_guard
cpu_feature_guard.check_cpu_features()
# TODO(phawkins): remove after minimium jaxlib version is 0.4.9 or newer.
try:
import jaxlib.utils as utils
except ImportError:
utils = None
import jaxlib.utils as utils
import jaxlib.xla_client as xla_client
import jaxlib.lapack as lapack

View File

@ -48,7 +48,6 @@ from jax._src.lax import lax as lax_internal
from jax._src.lax import utils as lax_utils
from jax._src.lib.mlir import ir
from jax._src.lib import gpu_prng
from jax._src.lib import version as jaxlib_version
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy.array_methods import (
@ -1078,18 +1077,10 @@ def _threefry2x32_gpu_lowering(lowering_func, ctx, k1, k2, x1, x2):
length = int(out_len) # will be passed statically
output_shape = None
if (jaxlib_version >= (0, 4, 9)):
return lowering_func(
(_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
(_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length,
output_shape)
else:
if output_shape is not None:
raise ValueError("native lowering with shape polymorphism "
"for threefry on GPU requires jaxlib version 0.4.9")
return lowering_func(
(_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
(_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length)
return lowering_func(
(_broadcast(k1, k1_aval), _broadcast(k2, k2_aval)),
(_broadcast(x1, x1_aval), _broadcast(x2, x2_aval)), length,
output_shape)
threefry2x32_p = core.Primitive("threefry2x32")
threefry2x32_p.multiple_results = True

View File

@ -58,17 +58,7 @@ if TYPE_CHECKING:
return list(zip(*args))
else:
# TODO(phawkins): remove the hasattr condition after jaxlib 0.4.9 is the
# minimum
if hasattr(jaxlib_utils, 'safe_zip'):
safe_zip = jaxlib_utils.safe_zip
else:
def safe_zip(*args):
args = list(map(list, args))
n = len(args[0])
for arg in args[1:]:
assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
return list(zip(*args))
safe_zip = jaxlib_utils.safe_zip
if TYPE_CHECKING:
@ -95,17 +85,7 @@ if TYPE_CHECKING:
return list(map(f, *args))
else:
# TODO(phawkins): remove the hasattr condition after jaxlib 0.4.9 is the
# minimum
if hasattr(jaxlib_utils, 'safe_map'):
safe_map = jaxlib_utils.safe_map
else:
def safe_map(f, *args):
args = list(map(list, args))
n = len(args[0])
for arg in args[1:]:
assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
return list(map(f, *args))
safe_map = jaxlib_utils.safe_map
def unzip2(xys: Iterable[tuple[T1, T2]]
) -> tuple[tuple[T1, ...], tuple[T2, ...]]:

View File

@ -199,8 +199,6 @@ class CompilationCacheTest(jtu.JaxTestCase):
cc.get_cache_key(computation2, devices, compile_options, backend),
)
@unittest.skipIf(jax._src.lib.version < (0, 4, 9),
"Test requires jaxlib 0.4.9")
@parameterized.parameters([False, True])
def test_identical_computations_different_metadata(self, include_metadata):
f = lambda x, y: lax.mul(lax.add(x, y), 2)

View File

@ -13,7 +13,6 @@
# limitations under the License.
import operator
import unittest
from absl.testing import absltest
@ -122,8 +121,6 @@ class SafeMapTest(jtu.JaxTestCase):
util.safe_map(make_tuple, range(4), range(4, 8)),
)
@unittest.skipIf(not hasattr(jaxlib_utils, 'safe_map'),
"requires jaxlib 0.4.9")
def test_safe_map_errors(self):
with self.assertRaisesRegex(
TypeError, "safe_map requires at least 2 arguments"
@ -174,8 +171,6 @@ class SafeZipTest(jtu.JaxTestCase):
util.safe_zip(range(4), range(4, 8)),
)
@unittest.skipIf(not hasattr(jaxlib_utils, 'safe_zip'),
"requires jaxlib 0.4.9")
def test_safe_zip_errors(self):
with self.assertRaisesRegex(
TypeError, "safe_zip requires at least 1 argument"