Bump minimum jaxlib version to 0.4.31. The corresponding xla_extension_version is 279 and mlir_api_version is 57

PiperOrigin-RevId: 657400413
This commit is contained in:
Yash Katariya 2024-07-29 18:43:56 -07:00 committed by jax authors
parent 2106a25977
commit 30037547d7
12 changed files with 54 additions and 123 deletions

View File

@ -207,7 +207,7 @@ as in the following example:
>>> print(jax.jit(new_prim.bind).lower(1.).compiler_ir())
module @jit_bind attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32> {mhlo.layout_mode = "default"}) -> (tensor<f32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
%0 = stablehlo.custom_call @my_new_prim(%arg0) {api_version = 2 : i32} : (tensor<f32>) -> tensor<f32>
%0 = stablehlo.custom_call @my_new_prim(%arg0) {api_version = 2 : i32, backend_config = ""} : (tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
}

View File

@ -60,7 +60,6 @@ from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.partition_spec import PartitionSpec
@ -2492,11 +2491,8 @@ def maybe_recover_user_shardings(
def is_user_xla_layout_equal(ul: DeviceLocalLayout | AutoLayout,
xl: DeviceLocalLayout) -> bool:
if xla_extension_version >= 274:
if isinstance(ul, DeviceLocalLayout) and ul._tiling is None:
return ul.major_to_minor == xl.major_to_minor
else:
return ul == xl
if isinstance(ul, DeviceLocalLayout) and ul._tiling is None:
return ul.major_to_minor == xl.major_to_minor
else:
return ul == xl

View File

@ -21,7 +21,6 @@ from jax._src.dtypes import iinfo, issubdtype
from jax._src.sharding import Sharding
from jax._src.sharding_impls import AUTO as AutoSharding, is_auto
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
Shape = tuple[int, ...]
@ -31,96 +30,64 @@ class AutoLayout:
return "AUTO"
if xla_extension_version >= 274:
class DeviceLocalLayout:
major_to_minor: tuple[int, ...]
_tiling: tuple[tuple[int, ...], ...] | None
_sub_byte_element_size_in_bits: int
class DeviceLocalLayout:
major_to_minor: tuple[int, ...]
_tiling: tuple[tuple[int, ...], ...] | None
_sub_byte_element_size_in_bits: int
AUTO = AutoLayout()
AUTO = AutoLayout()
def __init__(self, major_to_minor: tuple[int, ...],
_tiling: tuple[tuple[int, ...], ...] | None = None,
_sub_byte_element_size_in_bits: int = 0):
self.major_to_minor = tuple(major_to_minor)
self._tiling = None if _tiling is None else tuple(map(tuple, _tiling))
self._sub_byte_element_size_in_bits = _sub_byte_element_size_in_bits
def __init__(self, major_to_minor: tuple[int, ...],
_tiling: tuple[tuple[int, ...], ...] | None = None,
_sub_byte_element_size_in_bits: int = 0):
self.major_to_minor = tuple(major_to_minor)
self._tiling = None if _tiling is None else tuple(map(tuple, _tiling))
self._sub_byte_element_size_in_bits = _sub_byte_element_size_in_bits
@staticmethod
def from_pjrt_layout(pjrt_layout: xc.PjRtLayout):
xla_layout = pjrt_layout._xla_layout()
return DeviceLocalLayout(xla_layout.minor_to_major()[::-1], # pytype: disable=wrong-arg-types
xla_layout.tiling(),
xla_layout.element_size_in_bits())
@staticmethod
def from_pjrt_layout(pjrt_layout: xc.PjRtLayout):
xla_layout = pjrt_layout._xla_layout()
return DeviceLocalLayout(xla_layout.minor_to_major()[::-1], # pytype: disable=wrong-arg-types
xla_layout.tiling(),
xla_layout.element_size_in_bits())
def __repr__(self):
return (
f'DeviceLocalLayout(major_to_minor={self.major_to_minor},'
f' _tiling={self._tiling},'
f' _sub_byte_element_size_in_bits={self._sub_byte_element_size_in_bits})'
)
def __repr__(self):
return (
f'DeviceLocalLayout(major_to_minor={self.major_to_minor},'
f' _tiling={self._tiling},'
f' _sub_byte_element_size_in_bits={self._sub_byte_element_size_in_bits})'
)
def __hash__(self):
return hash((self.major_to_minor, self._tiling,
self._sub_byte_element_size_in_bits))
def __hash__(self):
return hash((self.major_to_minor, self._tiling,
self._sub_byte_element_size_in_bits))
def __eq__(self, other):
if not isinstance(other, DeviceLocalLayout):
return False
return (self.major_to_minor == other.major_to_minor and
self._tiling == other._tiling and
self._sub_byte_element_size_in_bits == other._sub_byte_element_size_in_bits)
def __eq__(self, other):
if not isinstance(other, DeviceLocalLayout):
return False
return (self.major_to_minor == other.major_to_minor and
self._tiling == other._tiling and
self._sub_byte_element_size_in_bits == other._sub_byte_element_size_in_bits)
def _to_xla_layout(self, dtype) -> str:
if self._tiling is None:
xla_layout = xc.Layout(self.major_to_minor[::-1])
def _to_xla_layout(self, dtype) -> str:
if self._tiling is None:
xla_layout = xc.Layout(self.major_to_minor[::-1])
else:
if self._sub_byte_element_size_in_bits != 0:
sub_byte_size = self._sub_byte_element_size_in_bits
elif issubdtype(dtype, np.integer):
sub_byte_size = iinfo(dtype).bits if iinfo(dtype).bits < 8 else 0
else:
if self._sub_byte_element_size_in_bits != 0:
sub_byte_size = self._sub_byte_element_size_in_bits
elif issubdtype(dtype, np.integer):
sub_byte_size = iinfo(dtype).bits if iinfo(dtype).bits < 8 else 0
else:
sub_byte_size = 0
xla_layout = xc.Layout(self.major_to_minor[::-1], self._tiling, # type: ignore
sub_byte_size)
return str(xla_layout)
sub_byte_size = 0
xla_layout = xc.Layout(self.major_to_minor[::-1], self._tiling, # type: ignore
sub_byte_size)
return str(xla_layout)
def check_compatible_aval(self, aval_shape: Shape):
if len(self.major_to_minor) != len(aval_shape):
raise ValueError(
f'Length of major_to_minor and the rank of the value should match.'
f' Got major_to_minor={self.major_to_minor} and shape={aval_shape}')
else:
class DeviceLocalLayout: # type: ignore
layout: xc.PjRtLayout
AUTO = AutoLayout()
def __init__(self, layout: xc.PjRtLayout):
self._layout = layout
self._layout_str = str(self._layout)
@staticmethod
def from_pjrt_layout(pjrt_layout: xc.PjRtLayout):
return DeviceLocalLayout(pjrt_layout) # type: ignore
def __repr__(self):
return f'DeviceLocalLayout({self._layout_str})'
def __hash__(self):
return hash(self._layout)
def __eq__(self, other):
if not isinstance(other, DeviceLocalLayout):
return False
return self._layout == other._layout
def _to_xla_layout(self, dtype) -> str:
return self._layout_str
def check_compatible_aval(self, aval_shape: Shape):
pass
def check_compatible_aval(self, aval_shape: Shape):
if len(self.major_to_minor) != len(aval_shape):
raise ValueError(
f'Length of major_to_minor and the rank of the value should match.'
f' Got major_to_minor={self.major_to_minor} and shape={aval_shape}')
LayoutOptions = Union[DeviceLocalLayout, None, AutoLayout] # pytype: disable=invalid-annotation

View File

@ -63,7 +63,6 @@ from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src import sharding
from jax._src.sharding_impls import (
NamedSharding, GSPMDSharding,
@ -714,7 +713,7 @@ def _infer_params(
resource_env = None
pjit_mesh = None
skip_cache = xla_extension_version < 273 or config.dynamic_shapes.value
skip_cache = config.dynamic_shapes.value
if not skip_cache:
signature, dynargs = jax_jit.parse_arguments(
args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums,

View File

@ -32,7 +32,6 @@ from jax._src import tree_util
from jax._src import util
from jax._src import xla_bridge
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.op_shardings import (
are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated)
from jax._src.partition_spec import PartitionSpec
@ -1065,8 +1064,6 @@ def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()):
_check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes)
return parsed_pspec
if xla_extension_version < 279:
preprocess_with_manual = preprocess
def prepare_axis_resources(axis_resources,
arg_name,

View File

@ -133,7 +133,7 @@ def _get_cmdclass(pkg_source_path):
__version__ = _get_version_string()
_minimum_jaxlib_version = "0.4.30"
_minimum_jaxlib_version = "0.4.31"
def _version_as_tuple(version_str):
return tuple(int(i) for i in version_str.split(".") if i.isdigit())

View File

@ -62,7 +62,6 @@ from jax._src.interpreters import partial_eval as pe
from jax._src.compilation_cache import is_persistent_cache_enabled
from jax._src.lib import xla_client
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
import jax._src.util as jax_util
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
import jax.custom_batching
@ -2641,7 +2640,6 @@ class APITest(jtu.JaxTestCase):
self.assertEqual(count[0], 1)
@unittest.skipIf(xla_extension_version <= 273, "requires jaxlib 0.4.31")
def test_jit_infer_params_cache(self):
def f(x):
return x
@ -4427,7 +4425,6 @@ class APITest(jtu.JaxTestCase):
g = jax.grad(f, argnums=-1)
g(x, y) # doesn't crash
@unittest.skipIf(xla_extension_version < 272, "requires jaxlib 0.4.31")
def test_jit_negative_static_argnums(self):
@partial(jax.jit, static_argnums=-1)
def g(x, y):

View File

@ -15,7 +15,6 @@
from absl.testing import absltest
import jax
from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version
jax.config.parse_flags_with_absl()
@ -27,9 +26,6 @@ class DeviceTest(jtu.JaxTestCase):
# TODO(pobudzey): Add a test for rocm devices when available.
if jtu.is_device_cuda():
if xla_extension_version < 276:
self.skipTest('requires jaxlib 0.4.31')
self.assertEqual(device.platform, 'gpu')
self.assertEqual(repr(device), 'CudaDevice(id=0)')
elif jtu.test_device_matches(['tpu']):

View File

@ -19,7 +19,6 @@ update these tests.
import dataclasses
from functools import partial
import itertools
import logging
import math
from absl.testing import absltest, parameterized
@ -64,7 +63,6 @@ from jax._src import config
from jax._src import test_util as jtu
from jax._src.lib import cuda_versions
from jax._src.lib import version as jaxlib_version
from jax._src.lib import xla_extension_version
config.parse_flags_with_absl()
@ -591,8 +589,6 @@ class CompatTest(bctu.CompatTestBase):
self.run_one_test(func, data)
def test_cuda_threefry2x32(self):
logging.info("test_cuda_threefry2x32: xla_extension_version: %s",
xla_extension_version)
def func(x):
return jax.random.uniform(x, (2, 4), dtype=np.float32)

View File

@ -15,7 +15,6 @@
import os
import numpy as np
import unittest
from absl.testing import absltest, parameterized
import jax
@ -30,7 +29,6 @@ from jax._src import linear_util
from jax._src import prng
from jax._src import test_util as jtu
from jax._src.interpreters import mlir
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.extend import ffi
@ -124,7 +122,6 @@ class FfiTest(jtu.JaxTestCase):
dtype=(np.int32,),
)
@jtu.run_on_devices("gpu")
@unittest.skipIf(xla_extension_version < 272, "requires jaxlib 0.4.31")
def testFfiCall(self, shape, dtype):
pivots_size = shape[-1]
permutation_size = 2 * pivots_size
@ -140,7 +137,6 @@ class FfiTest(jtu.JaxTestCase):
vectorized=(False, True),
)
@jtu.run_on_devices("gpu")
@unittest.skipIf(xla_extension_version < 272, "requires jaxlib 0.4.31")
def testFfiCallBatching(self, shape, dtype, vectorized):
shape = (10,) + shape
pivots_size = shape[-1]

View File

@ -25,7 +25,6 @@ from jax._src import config
from jax._src.layout import Layout, DeviceLocalLayout as DLL
from jax._src import test_util as jtu
from jax._src.util import safe_zip
from jax._src.lib import xla_extension_version
config.parse_flags_with_absl()
@ -405,9 +404,6 @@ class LayoutTest(jtu.JaxTestCase):
self.assertArraysEqual(out, inp.T)
def test_device_put_user_concrete_layout(self):
if xla_extension_version < 274:
self.skipTest('Requires xla_extension_version >= 274')
shape = (8, 128)
np_inp = np.arange(math.prod(shape)).reshape(shape)
dll = DLL(major_to_minor=(1, 0))
@ -437,9 +433,6 @@ class LayoutTest(jtu.JaxTestCase):
custom_dll.major_to_minor)
def test_compatible_aval_error(self):
if xla_extension_version < 274:
self.skipTest('Requires xla_extension_version >= 274')
custom_dll = DLL(major_to_minor=(0, 1, 2))
l = Layout(custom_dll, SingleDeviceSharding(jax.devices()[0]))
inp = np.arange(8)
@ -454,9 +447,6 @@ class LayoutTest(jtu.JaxTestCase):
f(inp)
def test_incompatible_aval_error_device_put(self):
if xla_extension_version < 274:
self.skipTest('Requires xla_extension_version >= 274')
custom_dll = DLL(major_to_minor=(0, 1, 2))
l = Layout(custom_dll, SingleDeviceSharding(jax.devices()[0]))
inp = np.arange(8)

View File

@ -16,7 +16,6 @@
from functools import partial
import itertools
import unittest
import numpy as np
import scipy
@ -34,7 +33,6 @@ from jax._src import config
from jax._src.lax import linalg as lax_linalg
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.lib import xla_extension_version
from jax._src.numpy.util import promote_dtypes_inexact
config.parse_flags_with_absl()
@ -1625,7 +1623,6 @@ class ScipyLinalgTest(jtu.JaxTestCase):
(a, b),
(a, b))
@unittest.skipIf(xla_extension_version < 277, "Requires jaxlib > 0.4.30")
def testTriangularSolveSingularBatched(self):
x = jnp.array([[1, 1], [0, 0]], dtype=np.float32)
y = jnp.array([[1], [1.]], dtype=np.float32)