diff --git a/docs/export/export.md b/docs/export/export.md index 3a176a5e1..9e6597cef 100644 --- a/docs/export/export.md +++ b/docs/export/export.md @@ -207,7 +207,7 @@ as in the following example: >>> print(jax.jit(new_prim.bind).lower(1.).compiler_ir()) module @jit_bind attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor {mhlo.layout_mode = "default"}) -> (tensor {jax.result_info = "", mhlo.layout_mode = "default"}) { - %0 = stablehlo.custom_call @my_new_prim(%arg0) {api_version = 2 : i32} : (tensor) -> tensor + %0 = stablehlo.custom_call @my_new_prim(%arg0) {api_version = 2 : i32, backend_config = ""} : (tensor) -> tensor return %0 : tensor } } diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index c0402a30a..3774fefe7 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -60,7 +60,6 @@ from jax._src.interpreters import mlir from jax._src.interpreters import xla from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.partition_spec import PartitionSpec @@ -2492,11 +2491,8 @@ def maybe_recover_user_shardings( def is_user_xla_layout_equal(ul: DeviceLocalLayout | AutoLayout, xl: DeviceLocalLayout) -> bool: - if xla_extension_version >= 274: - if isinstance(ul, DeviceLocalLayout) and ul._tiling is None: - return ul.major_to_minor == xl.major_to_minor - else: - return ul == xl + if isinstance(ul, DeviceLocalLayout) and ul._tiling is None: + return ul.major_to_minor == xl.major_to_minor else: return ul == xl diff --git a/jax/_src/layout.py b/jax/_src/layout.py index 1b81c4e44..c943c5e2e 100644 --- a/jax/_src/layout.py +++ b/jax/_src/layout.py @@ -21,7 +21,6 @@ from jax._src.dtypes import iinfo, issubdtype from jax._src.sharding import Sharding from jax._src.sharding_impls import AUTO as AutoSharding, is_auto from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version Shape = tuple[int, ...] @@ -31,96 +30,64 @@ class AutoLayout: return "AUTO" -if xla_extension_version >= 274: - class DeviceLocalLayout: - major_to_minor: tuple[int, ...] - _tiling: tuple[tuple[int, ...], ...] | None - _sub_byte_element_size_in_bits: int +class DeviceLocalLayout: + major_to_minor: tuple[int, ...] + _tiling: tuple[tuple[int, ...], ...] | None + _sub_byte_element_size_in_bits: int - AUTO = AutoLayout() + AUTO = AutoLayout() - def __init__(self, major_to_minor: tuple[int, ...], - _tiling: tuple[tuple[int, ...], ...] | None = None, - _sub_byte_element_size_in_bits: int = 0): - self.major_to_minor = tuple(major_to_minor) - self._tiling = None if _tiling is None else tuple(map(tuple, _tiling)) - self._sub_byte_element_size_in_bits = _sub_byte_element_size_in_bits + def __init__(self, major_to_minor: tuple[int, ...], + _tiling: tuple[tuple[int, ...], ...] | None = None, + _sub_byte_element_size_in_bits: int = 0): + self.major_to_minor = tuple(major_to_minor) + self._tiling = None if _tiling is None else tuple(map(tuple, _tiling)) + self._sub_byte_element_size_in_bits = _sub_byte_element_size_in_bits - @staticmethod - def from_pjrt_layout(pjrt_layout: xc.PjRtLayout): - xla_layout = pjrt_layout._xla_layout() - return DeviceLocalLayout(xla_layout.minor_to_major()[::-1], # pytype: disable=wrong-arg-types - xla_layout.tiling(), - xla_layout.element_size_in_bits()) + @staticmethod + def from_pjrt_layout(pjrt_layout: xc.PjRtLayout): + xla_layout = pjrt_layout._xla_layout() + return DeviceLocalLayout(xla_layout.minor_to_major()[::-1], # pytype: disable=wrong-arg-types + xla_layout.tiling(), + xla_layout.element_size_in_bits()) - def __repr__(self): - return ( - f'DeviceLocalLayout(major_to_minor={self.major_to_minor},' - f' _tiling={self._tiling},' - f' _sub_byte_element_size_in_bits={self._sub_byte_element_size_in_bits})' - ) + def __repr__(self): + return ( + f'DeviceLocalLayout(major_to_minor={self.major_to_minor},' + f' _tiling={self._tiling},' + f' _sub_byte_element_size_in_bits={self._sub_byte_element_size_in_bits})' + ) - def __hash__(self): - return hash((self.major_to_minor, self._tiling, - self._sub_byte_element_size_in_bits)) + def __hash__(self): + return hash((self.major_to_minor, self._tiling, + self._sub_byte_element_size_in_bits)) - def __eq__(self, other): - if not isinstance(other, DeviceLocalLayout): - return False - return (self.major_to_minor == other.major_to_minor and - self._tiling == other._tiling and - self._sub_byte_element_size_in_bits == other._sub_byte_element_size_in_bits) + def __eq__(self, other): + if not isinstance(other, DeviceLocalLayout): + return False + return (self.major_to_minor == other.major_to_minor and + self._tiling == other._tiling and + self._sub_byte_element_size_in_bits == other._sub_byte_element_size_in_bits) - def _to_xla_layout(self, dtype) -> str: - if self._tiling is None: - xla_layout = xc.Layout(self.major_to_minor[::-1]) + def _to_xla_layout(self, dtype) -> str: + if self._tiling is None: + xla_layout = xc.Layout(self.major_to_minor[::-1]) + else: + if self._sub_byte_element_size_in_bits != 0: + sub_byte_size = self._sub_byte_element_size_in_bits + elif issubdtype(dtype, np.integer): + sub_byte_size = iinfo(dtype).bits if iinfo(dtype).bits < 8 else 0 else: - if self._sub_byte_element_size_in_bits != 0: - sub_byte_size = self._sub_byte_element_size_in_bits - elif issubdtype(dtype, np.integer): - sub_byte_size = iinfo(dtype).bits if iinfo(dtype).bits < 8 else 0 - else: - sub_byte_size = 0 - xla_layout = xc.Layout(self.major_to_minor[::-1], self._tiling, # type: ignore - sub_byte_size) - return str(xla_layout) + sub_byte_size = 0 + xla_layout = xc.Layout(self.major_to_minor[::-1], self._tiling, # type: ignore + sub_byte_size) + return str(xla_layout) - def check_compatible_aval(self, aval_shape: Shape): - if len(self.major_to_minor) != len(aval_shape): - raise ValueError( - f'Length of major_to_minor and the rank of the value should match.' - f' Got major_to_minor={self.major_to_minor} and shape={aval_shape}') - -else: - class DeviceLocalLayout: # type: ignore - layout: xc.PjRtLayout - - AUTO = AutoLayout() - - def __init__(self, layout: xc.PjRtLayout): - self._layout = layout - self._layout_str = str(self._layout) - - @staticmethod - def from_pjrt_layout(pjrt_layout: xc.PjRtLayout): - return DeviceLocalLayout(pjrt_layout) # type: ignore - - def __repr__(self): - return f'DeviceLocalLayout({self._layout_str})' - - def __hash__(self): - return hash(self._layout) - - def __eq__(self, other): - if not isinstance(other, DeviceLocalLayout): - return False - return self._layout == other._layout - - def _to_xla_layout(self, dtype) -> str: - return self._layout_str - - def check_compatible_aval(self, aval_shape: Shape): - pass + def check_compatible_aval(self, aval_shape: Shape): + if len(self.major_to_minor) != len(aval_shape): + raise ValueError( + f'Length of major_to_minor and the rank of the value should match.' + f' Got major_to_minor={self.major_to_minor} and shape={aval_shape}') LayoutOptions = Union[DeviceLocalLayout, None, AutoLayout] # pytype: disable=invalid-annotation diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index f1db109a8..28ea5ec81 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -63,7 +63,6 @@ from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect from jax._src.lib import jax_jit from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src import sharding from jax._src.sharding_impls import ( NamedSharding, GSPMDSharding, @@ -714,7 +713,7 @@ def _infer_params( resource_env = None pjit_mesh = None - skip_cache = xla_extension_version < 273 or config.dynamic_shapes.value + skip_cache = config.dynamic_shapes.value if not skip_cache: signature, dynargs = jax_jit.parse_arguments( args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums, diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 99a4e473f..4342f9f8c 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -32,7 +32,6 @@ from jax._src import tree_util from jax._src import util from jax._src import xla_bridge from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src.op_shardings import ( are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated) from jax._src.partition_spec import PartitionSpec @@ -1065,8 +1064,6 @@ def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()): _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes) return parsed_pspec -if xla_extension_version < 279: - preprocess_with_manual = preprocess def prepare_axis_resources(axis_resources, arg_name, diff --git a/jax/version.py b/jax/version.py index 736984f9f..cc690e02c 100644 --- a/jax/version.py +++ b/jax/version.py @@ -133,7 +133,7 @@ def _get_cmdclass(pkg_source_path): __version__ = _get_version_string() -_minimum_jaxlib_version = "0.4.30" +_minimum_jaxlib_version = "0.4.31" def _version_as_tuple(version_str): return tuple(int(i) for i in version_str.split(".") if i.isdigit()) diff --git a/tests/api_test.py b/tests/api_test.py index 8f4ebcfe5..1224ade33 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -62,7 +62,6 @@ from jax._src.interpreters import partial_eval as pe from jax._src.compilation_cache import is_persistent_cache_enabled from jax._src.lib import xla_client from jax._src.lib import xla_extension -from jax._src.lib import xla_extension_version import jax._src.util as jax_util from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint import jax.custom_batching @@ -2641,7 +2640,6 @@ class APITest(jtu.JaxTestCase): self.assertEqual(count[0], 1) - @unittest.skipIf(xla_extension_version <= 273, "requires jaxlib 0.4.31") def test_jit_infer_params_cache(self): def f(x): return x @@ -4427,7 +4425,6 @@ class APITest(jtu.JaxTestCase): g = jax.grad(f, argnums=-1) g(x, y) # doesn't crash - @unittest.skipIf(xla_extension_version < 272, "requires jaxlib 0.4.31") def test_jit_negative_static_argnums(self): @partial(jax.jit, static_argnums=-1) def g(x, y): diff --git a/tests/device_test.py b/tests/device_test.py index 2ffa5c1e3..d8f2ae65b 100644 --- a/tests/device_test.py +++ b/tests/device_test.py @@ -15,7 +15,6 @@ from absl.testing import absltest import jax from jax._src import test_util as jtu -from jax._src.lib import xla_extension_version jax.config.parse_flags_with_absl() @@ -27,9 +26,6 @@ class DeviceTest(jtu.JaxTestCase): # TODO(pobudzey): Add a test for rocm devices when available. if jtu.is_device_cuda(): - if xla_extension_version < 276: - self.skipTest('requires jaxlib 0.4.31') - self.assertEqual(device.platform, 'gpu') self.assertEqual(repr(device), 'CudaDevice(id=0)') elif jtu.test_device_matches(['tpu']): diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index dfd211a9c..cf8b824d6 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -19,7 +19,6 @@ update these tests. import dataclasses from functools import partial import itertools -import logging import math from absl.testing import absltest, parameterized @@ -64,7 +63,6 @@ from jax._src import config from jax._src import test_util as jtu from jax._src.lib import cuda_versions from jax._src.lib import version as jaxlib_version -from jax._src.lib import xla_extension_version config.parse_flags_with_absl() @@ -591,8 +589,6 @@ class CompatTest(bctu.CompatTestBase): self.run_one_test(func, data) def test_cuda_threefry2x32(self): - logging.info("test_cuda_threefry2x32: xla_extension_version: %s", - xla_extension_version) def func(x): return jax.random.uniform(x, (2, 4), dtype=np.float32) diff --git a/tests/extend_test.py b/tests/extend_test.py index c419878c8..939b4d3c4 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -15,7 +15,6 @@ import os import numpy as np -import unittest from absl.testing import absltest, parameterized import jax @@ -30,7 +29,6 @@ from jax._src import linear_util from jax._src import prng from jax._src import test_util as jtu from jax._src.interpreters import mlir -from jax._src.lib import xla_extension_version from jax._src.lib.mlir import ir from jax._src.extend import ffi @@ -124,7 +122,6 @@ class FfiTest(jtu.JaxTestCase): dtype=(np.int32,), ) @jtu.run_on_devices("gpu") - @unittest.skipIf(xla_extension_version < 272, "requires jaxlib 0.4.31") def testFfiCall(self, shape, dtype): pivots_size = shape[-1] permutation_size = 2 * pivots_size @@ -140,7 +137,6 @@ class FfiTest(jtu.JaxTestCase): vectorized=(False, True), ) @jtu.run_on_devices("gpu") - @unittest.skipIf(xla_extension_version < 272, "requires jaxlib 0.4.31") def testFfiCallBatching(self, shape, dtype, vectorized): shape = (10,) + shape pivots_size = shape[-1] diff --git a/tests/layout_test.py b/tests/layout_test.py index 97e426bad..7972d44d3 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -25,7 +25,6 @@ from jax._src import config from jax._src.layout import Layout, DeviceLocalLayout as DLL from jax._src import test_util as jtu from jax._src.util import safe_zip -from jax._src.lib import xla_extension_version config.parse_flags_with_absl() @@ -405,9 +404,6 @@ class LayoutTest(jtu.JaxTestCase): self.assertArraysEqual(out, inp.T) def test_device_put_user_concrete_layout(self): - if xla_extension_version < 274: - self.skipTest('Requires xla_extension_version >= 274') - shape = (8, 128) np_inp = np.arange(math.prod(shape)).reshape(shape) dll = DLL(major_to_minor=(1, 0)) @@ -437,9 +433,6 @@ class LayoutTest(jtu.JaxTestCase): custom_dll.major_to_minor) def test_compatible_aval_error(self): - if xla_extension_version < 274: - self.skipTest('Requires xla_extension_version >= 274') - custom_dll = DLL(major_to_minor=(0, 1, 2)) l = Layout(custom_dll, SingleDeviceSharding(jax.devices()[0])) inp = np.arange(8) @@ -454,9 +447,6 @@ class LayoutTest(jtu.JaxTestCase): f(inp) def test_incompatible_aval_error_device_put(self): - if xla_extension_version < 274: - self.skipTest('Requires xla_extension_version >= 274') - custom_dll = DLL(major_to_minor=(0, 1, 2)) l = Layout(custom_dll, SingleDeviceSharding(jax.devices()[0])) inp = np.arange(8) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index dd0ae38d9..29e212b14 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -16,7 +16,6 @@ from functools import partial import itertools -import unittest import numpy as np import scipy @@ -34,7 +33,6 @@ from jax._src import config from jax._src.lax import linalg as lax_linalg from jax._src import test_util as jtu from jax._src import xla_bridge -from jax._src.lib import xla_extension_version from jax._src.numpy.util import promote_dtypes_inexact config.parse_flags_with_absl() @@ -1625,7 +1623,6 @@ class ScipyLinalgTest(jtu.JaxTestCase): (a, b), (a, b)) - @unittest.skipIf(xla_extension_version < 277, "Requires jaxlib > 0.4.30") def testTriangularSolveSingularBatched(self): x = jnp.array([[1, 1], [0, 0]], dtype=np.float32) y = jnp.array([[1], [1.]], dtype=np.float32)