Bump minimum jaxlib version to 0.4.31. The corresponding xla_extension_version is 279 and mlir_api_version is 57

PiperOrigin-RevId: 657400413
This commit is contained in:
Yash Katariya 2024-07-29 18:43:56 -07:00 committed by jax authors
parent 2106a25977
commit 30037547d7
12 changed files with 54 additions and 123 deletions

View File

@ -207,7 +207,7 @@ as in the following example:
>>> print(jax.jit(new_prim.bind).lower(1.).compiler_ir()) >>> print(jax.jit(new_prim.bind).lower(1.).compiler_ir())
module @jit_bind attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { module @jit_bind attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32> {mhlo.layout_mode = "default"}) -> (tensor<f32> {jax.result_info = "", mhlo.layout_mode = "default"}) { func.func public @main(%arg0: tensor<f32> {mhlo.layout_mode = "default"}) -> (tensor<f32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
%0 = stablehlo.custom_call @my_new_prim(%arg0) {api_version = 2 : i32} : (tensor<f32>) -> tensor<f32> %0 = stablehlo.custom_call @my_new_prim(%arg0) {api_version = 2 : i32, backend_config = ""} : (tensor<f32>) -> tensor<f32>
return %0 : tensor<f32> return %0 : tensor<f32>
} }
} }

View File

@ -60,7 +60,6 @@ from jax._src.interpreters import mlir
from jax._src.interpreters import xla from jax._src.interpreters import xla
from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout
from jax._src.lib import xla_client as xc from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo from jax._src.lib.mlir.dialects import hlo
from jax._src.partition_spec import PartitionSpec from jax._src.partition_spec import PartitionSpec
@ -2492,13 +2491,10 @@ def maybe_recover_user_shardings(
def is_user_xla_layout_equal(ul: DeviceLocalLayout | AutoLayout, def is_user_xla_layout_equal(ul: DeviceLocalLayout | AutoLayout,
xl: DeviceLocalLayout) -> bool: xl: DeviceLocalLayout) -> bool:
if xla_extension_version >= 274:
if isinstance(ul, DeviceLocalLayout) and ul._tiling is None: if isinstance(ul, DeviceLocalLayout) and ul._tiling is None:
return ul.major_to_minor == xl.major_to_minor return ul.major_to_minor == xl.major_to_minor
else: else:
return ul == xl return ul == xl
else:
return ul == xl
def _check_user_xla_layout(ul, xl, what: str): def _check_user_xla_layout(ul, xl, what: str):
if not is_user_xla_layout_equal(ul, xl): if not is_user_xla_layout_equal(ul, xl):

View File

@ -21,7 +21,6 @@ from jax._src.dtypes import iinfo, issubdtype
from jax._src.sharding import Sharding from jax._src.sharding import Sharding
from jax._src.sharding_impls import AUTO as AutoSharding, is_auto from jax._src.sharding_impls import AUTO as AutoSharding, is_auto
from jax._src.lib import xla_client as xc from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
Shape = tuple[int, ...] Shape = tuple[int, ...]
@ -31,7 +30,6 @@ class AutoLayout:
return "AUTO" return "AUTO"
if xla_extension_version >= 274:
class DeviceLocalLayout: class DeviceLocalLayout:
major_to_minor: tuple[int, ...] major_to_minor: tuple[int, ...]
_tiling: tuple[tuple[int, ...], ...] | None _tiling: tuple[tuple[int, ...], ...] | None
@ -91,37 +89,6 @@ if xla_extension_version >= 274:
f'Length of major_to_minor and the rank of the value should match.' f'Length of major_to_minor and the rank of the value should match.'
f' Got major_to_minor={self.major_to_minor} and shape={aval_shape}') f' Got major_to_minor={self.major_to_minor} and shape={aval_shape}')
else:
class DeviceLocalLayout: # type: ignore
layout: xc.PjRtLayout
AUTO = AutoLayout()
def __init__(self, layout: xc.PjRtLayout):
self._layout = layout
self._layout_str = str(self._layout)
@staticmethod
def from_pjrt_layout(pjrt_layout: xc.PjRtLayout):
return DeviceLocalLayout(pjrt_layout) # type: ignore
def __repr__(self):
return f'DeviceLocalLayout({self._layout_str})'
def __hash__(self):
return hash(self._layout)
def __eq__(self, other):
if not isinstance(other, DeviceLocalLayout):
return False
return self._layout == other._layout
def _to_xla_layout(self, dtype) -> str:
return self._layout_str
def check_compatible_aval(self, aval_shape: Shape):
pass
LayoutOptions = Union[DeviceLocalLayout, None, AutoLayout] # pytype: disable=invalid-annotation LayoutOptions = Union[DeviceLocalLayout, None, AutoLayout] # pytype: disable=invalid-annotation
ShardingOptions = Union[Sharding, None, AutoSharding] ShardingOptions = Union[Sharding, None, AutoSharding]

View File

@ -63,7 +63,6 @@ from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib import jax_jit from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src import sharding from jax._src import sharding
from jax._src.sharding_impls import ( from jax._src.sharding_impls import (
NamedSharding, GSPMDSharding, NamedSharding, GSPMDSharding,
@ -714,7 +713,7 @@ def _infer_params(
resource_env = None resource_env = None
pjit_mesh = None pjit_mesh = None
skip_cache = xla_extension_version < 273 or config.dynamic_shapes.value skip_cache = config.dynamic_shapes.value
if not skip_cache: if not skip_cache:
signature, dynargs = jax_jit.parse_arguments( signature, dynargs = jax_jit.parse_arguments(
args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums, args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums,

View File

@ -32,7 +32,6 @@ from jax._src import tree_util
from jax._src import util from jax._src import util
from jax._src import xla_bridge from jax._src import xla_bridge
from jax._src.lib import xla_client as xc from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.op_shardings import ( from jax._src.op_shardings import (
are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated) are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated)
from jax._src.partition_spec import PartitionSpec from jax._src.partition_spec import PartitionSpec
@ -1065,8 +1064,6 @@ def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()):
_check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes) _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes)
return parsed_pspec return parsed_pspec
if xla_extension_version < 279:
preprocess_with_manual = preprocess
def prepare_axis_resources(axis_resources, def prepare_axis_resources(axis_resources,
arg_name, arg_name,

View File

@ -133,7 +133,7 @@ def _get_cmdclass(pkg_source_path):
__version__ = _get_version_string() __version__ = _get_version_string()
_minimum_jaxlib_version = "0.4.30" _minimum_jaxlib_version = "0.4.31"
def _version_as_tuple(version_str): def _version_as_tuple(version_str):
return tuple(int(i) for i in version_str.split(".") if i.isdigit()) return tuple(int(i) for i in version_str.split(".") if i.isdigit())

View File

@ -62,7 +62,6 @@ from jax._src.interpreters import partial_eval as pe
from jax._src.compilation_cache import is_persistent_cache_enabled from jax._src.compilation_cache import is_persistent_cache_enabled
from jax._src.lib import xla_client from jax._src.lib import xla_client
from jax._src.lib import xla_extension from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
import jax._src.util as jax_util import jax._src.util as jax_util
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
import jax.custom_batching import jax.custom_batching
@ -2641,7 +2640,6 @@ class APITest(jtu.JaxTestCase):
self.assertEqual(count[0], 1) self.assertEqual(count[0], 1)
@unittest.skipIf(xla_extension_version <= 273, "requires jaxlib 0.4.31")
def test_jit_infer_params_cache(self): def test_jit_infer_params_cache(self):
def f(x): def f(x):
return x return x
@ -4427,7 +4425,6 @@ class APITest(jtu.JaxTestCase):
g = jax.grad(f, argnums=-1) g = jax.grad(f, argnums=-1)
g(x, y) # doesn't crash g(x, y) # doesn't crash
@unittest.skipIf(xla_extension_version < 272, "requires jaxlib 0.4.31")
def test_jit_negative_static_argnums(self): def test_jit_negative_static_argnums(self):
@partial(jax.jit, static_argnums=-1) @partial(jax.jit, static_argnums=-1)
def g(x, y): def g(x, y):

View File

@ -15,7 +15,6 @@
from absl.testing import absltest from absl.testing import absltest
import jax import jax
from jax._src import test_util as jtu from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version
jax.config.parse_flags_with_absl() jax.config.parse_flags_with_absl()
@ -27,9 +26,6 @@ class DeviceTest(jtu.JaxTestCase):
# TODO(pobudzey): Add a test for rocm devices when available. # TODO(pobudzey): Add a test for rocm devices when available.
if jtu.is_device_cuda(): if jtu.is_device_cuda():
if xla_extension_version < 276:
self.skipTest('requires jaxlib 0.4.31')
self.assertEqual(device.platform, 'gpu') self.assertEqual(device.platform, 'gpu')
self.assertEqual(repr(device), 'CudaDevice(id=0)') self.assertEqual(repr(device), 'CudaDevice(id=0)')
elif jtu.test_device_matches(['tpu']): elif jtu.test_device_matches(['tpu']):

View File

@ -19,7 +19,6 @@ update these tests.
import dataclasses import dataclasses
from functools import partial from functools import partial
import itertools import itertools
import logging
import math import math
from absl.testing import absltest, parameterized from absl.testing import absltest, parameterized
@ -64,7 +63,6 @@ from jax._src import config
from jax._src import test_util as jtu from jax._src import test_util as jtu
from jax._src.lib import cuda_versions from jax._src.lib import cuda_versions
from jax._src.lib import version as jaxlib_version from jax._src.lib import version as jaxlib_version
from jax._src.lib import xla_extension_version
config.parse_flags_with_absl() config.parse_flags_with_absl()
@ -591,8 +589,6 @@ class CompatTest(bctu.CompatTestBase):
self.run_one_test(func, data) self.run_one_test(func, data)
def test_cuda_threefry2x32(self): def test_cuda_threefry2x32(self):
logging.info("test_cuda_threefry2x32: xla_extension_version: %s",
xla_extension_version)
def func(x): def func(x):
return jax.random.uniform(x, (2, 4), dtype=np.float32) return jax.random.uniform(x, (2, 4), dtype=np.float32)

View File

@ -15,7 +15,6 @@
import os import os
import numpy as np import numpy as np
import unittest
from absl.testing import absltest, parameterized from absl.testing import absltest, parameterized
import jax import jax
@ -30,7 +29,6 @@ from jax._src import linear_util
from jax._src import prng from jax._src import prng
from jax._src import test_util as jtu from jax._src import test_util as jtu
from jax._src.interpreters import mlir from jax._src.interpreters import mlir
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir from jax._src.lib.mlir import ir
from jax._src.extend import ffi from jax._src.extend import ffi
@ -124,7 +122,6 @@ class FfiTest(jtu.JaxTestCase):
dtype=(np.int32,), dtype=(np.int32,),
) )
@jtu.run_on_devices("gpu") @jtu.run_on_devices("gpu")
@unittest.skipIf(xla_extension_version < 272, "requires jaxlib 0.4.31")
def testFfiCall(self, shape, dtype): def testFfiCall(self, shape, dtype):
pivots_size = shape[-1] pivots_size = shape[-1]
permutation_size = 2 * pivots_size permutation_size = 2 * pivots_size
@ -140,7 +137,6 @@ class FfiTest(jtu.JaxTestCase):
vectorized=(False, True), vectorized=(False, True),
) )
@jtu.run_on_devices("gpu") @jtu.run_on_devices("gpu")
@unittest.skipIf(xla_extension_version < 272, "requires jaxlib 0.4.31")
def testFfiCallBatching(self, shape, dtype, vectorized): def testFfiCallBatching(self, shape, dtype, vectorized):
shape = (10,) + shape shape = (10,) + shape
pivots_size = shape[-1] pivots_size = shape[-1]

View File

@ -25,7 +25,6 @@ from jax._src import config
from jax._src.layout import Layout, DeviceLocalLayout as DLL from jax._src.layout import Layout, DeviceLocalLayout as DLL
from jax._src import test_util as jtu from jax._src import test_util as jtu
from jax._src.util import safe_zip from jax._src.util import safe_zip
from jax._src.lib import xla_extension_version
config.parse_flags_with_absl() config.parse_flags_with_absl()
@ -405,9 +404,6 @@ class LayoutTest(jtu.JaxTestCase):
self.assertArraysEqual(out, inp.T) self.assertArraysEqual(out, inp.T)
def test_device_put_user_concrete_layout(self): def test_device_put_user_concrete_layout(self):
if xla_extension_version < 274:
self.skipTest('Requires xla_extension_version >= 274')
shape = (8, 128) shape = (8, 128)
np_inp = np.arange(math.prod(shape)).reshape(shape) np_inp = np.arange(math.prod(shape)).reshape(shape)
dll = DLL(major_to_minor=(1, 0)) dll = DLL(major_to_minor=(1, 0))
@ -437,9 +433,6 @@ class LayoutTest(jtu.JaxTestCase):
custom_dll.major_to_minor) custom_dll.major_to_minor)
def test_compatible_aval_error(self): def test_compatible_aval_error(self):
if xla_extension_version < 274:
self.skipTest('Requires xla_extension_version >= 274')
custom_dll = DLL(major_to_minor=(0, 1, 2)) custom_dll = DLL(major_to_minor=(0, 1, 2))
l = Layout(custom_dll, SingleDeviceSharding(jax.devices()[0])) l = Layout(custom_dll, SingleDeviceSharding(jax.devices()[0]))
inp = np.arange(8) inp = np.arange(8)
@ -454,9 +447,6 @@ class LayoutTest(jtu.JaxTestCase):
f(inp) f(inp)
def test_incompatible_aval_error_device_put(self): def test_incompatible_aval_error_device_put(self):
if xla_extension_version < 274:
self.skipTest('Requires xla_extension_version >= 274')
custom_dll = DLL(major_to_minor=(0, 1, 2)) custom_dll = DLL(major_to_minor=(0, 1, 2))
l = Layout(custom_dll, SingleDeviceSharding(jax.devices()[0])) l = Layout(custom_dll, SingleDeviceSharding(jax.devices()[0]))
inp = np.arange(8) inp = np.arange(8)

View File

@ -16,7 +16,6 @@
from functools import partial from functools import partial
import itertools import itertools
import unittest
import numpy as np import numpy as np
import scipy import scipy
@ -34,7 +33,6 @@ from jax._src import config
from jax._src.lax import linalg as lax_linalg from jax._src.lax import linalg as lax_linalg
from jax._src import test_util as jtu from jax._src import test_util as jtu
from jax._src import xla_bridge from jax._src import xla_bridge
from jax._src.lib import xla_extension_version
from jax._src.numpy.util import promote_dtypes_inexact from jax._src.numpy.util import promote_dtypes_inexact
config.parse_flags_with_absl() config.parse_flags_with_absl()
@ -1625,7 +1623,6 @@ class ScipyLinalgTest(jtu.JaxTestCase):
(a, b), (a, b),
(a, b)) (a, b))
@unittest.skipIf(xla_extension_version < 277, "Requires jaxlib > 0.4.30")
def testTriangularSolveSingularBatched(self): def testTriangularSolveSingularBatched(self):
x = jnp.array([[1, 1], [0, 0]], dtype=np.float32) x = jnp.array([[1, 1], [0, 0]], dtype=np.float32)
y = jnp.array([[1], [1.]], dtype=np.float32) y = jnp.array([[1], [1.]], dtype=np.float32)