mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Bump minimum jaxlib version to 0.4.31. The corresponding xla_extension_version is 279 and mlir_api_version is 57
PiperOrigin-RevId: 657400413
This commit is contained in:
parent
2106a25977
commit
30037547d7
@ -207,7 +207,7 @@ as in the following example:
|
||||
>>> print(jax.jit(new_prim.bind).lower(1.).compiler_ir())
|
||||
module @jit_bind attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
|
||||
func.func public @main(%arg0: tensor<f32> {mhlo.layout_mode = "default"}) -> (tensor<f32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
|
||||
%0 = stablehlo.custom_call @my_new_prim(%arg0) {api_version = 2 : i32} : (tensor<f32>) -> tensor<f32>
|
||||
%0 = stablehlo.custom_call @my_new_prim(%arg0) {api_version = 2 : i32, backend_config = ""} : (tensor<f32>) -> tensor<f32>
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
}
|
||||
|
@ -60,7 +60,6 @@ from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.partition_spec import PartitionSpec
|
||||
@ -2492,11 +2491,8 @@ def maybe_recover_user_shardings(
|
||||
|
||||
def is_user_xla_layout_equal(ul: DeviceLocalLayout | AutoLayout,
|
||||
xl: DeviceLocalLayout) -> bool:
|
||||
if xla_extension_version >= 274:
|
||||
if isinstance(ul, DeviceLocalLayout) and ul._tiling is None:
|
||||
return ul.major_to_minor == xl.major_to_minor
|
||||
else:
|
||||
return ul == xl
|
||||
if isinstance(ul, DeviceLocalLayout) and ul._tiling is None:
|
||||
return ul.major_to_minor == xl.major_to_minor
|
||||
else:
|
||||
return ul == xl
|
||||
|
||||
|
@ -21,7 +21,6 @@ from jax._src.dtypes import iinfo, issubdtype
|
||||
from jax._src.sharding import Sharding
|
||||
from jax._src.sharding_impls import AUTO as AutoSharding, is_auto
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
|
||||
Shape = tuple[int, ...]
|
||||
|
||||
@ -31,96 +30,64 @@ class AutoLayout:
|
||||
return "AUTO"
|
||||
|
||||
|
||||
if xla_extension_version >= 274:
|
||||
class DeviceLocalLayout:
|
||||
major_to_minor: tuple[int, ...]
|
||||
_tiling: tuple[tuple[int, ...], ...] | None
|
||||
_sub_byte_element_size_in_bits: int
|
||||
class DeviceLocalLayout:
|
||||
major_to_minor: tuple[int, ...]
|
||||
_tiling: tuple[tuple[int, ...], ...] | None
|
||||
_sub_byte_element_size_in_bits: int
|
||||
|
||||
AUTO = AutoLayout()
|
||||
AUTO = AutoLayout()
|
||||
|
||||
def __init__(self, major_to_minor: tuple[int, ...],
|
||||
_tiling: tuple[tuple[int, ...], ...] | None = None,
|
||||
_sub_byte_element_size_in_bits: int = 0):
|
||||
self.major_to_minor = tuple(major_to_minor)
|
||||
self._tiling = None if _tiling is None else tuple(map(tuple, _tiling))
|
||||
self._sub_byte_element_size_in_bits = _sub_byte_element_size_in_bits
|
||||
def __init__(self, major_to_minor: tuple[int, ...],
|
||||
_tiling: tuple[tuple[int, ...], ...] | None = None,
|
||||
_sub_byte_element_size_in_bits: int = 0):
|
||||
self.major_to_minor = tuple(major_to_minor)
|
||||
self._tiling = None if _tiling is None else tuple(map(tuple, _tiling))
|
||||
self._sub_byte_element_size_in_bits = _sub_byte_element_size_in_bits
|
||||
|
||||
@staticmethod
|
||||
def from_pjrt_layout(pjrt_layout: xc.PjRtLayout):
|
||||
xla_layout = pjrt_layout._xla_layout()
|
||||
return DeviceLocalLayout(xla_layout.minor_to_major()[::-1], # pytype: disable=wrong-arg-types
|
||||
xla_layout.tiling(),
|
||||
xla_layout.element_size_in_bits())
|
||||
@staticmethod
|
||||
def from_pjrt_layout(pjrt_layout: xc.PjRtLayout):
|
||||
xla_layout = pjrt_layout._xla_layout()
|
||||
return DeviceLocalLayout(xla_layout.minor_to_major()[::-1], # pytype: disable=wrong-arg-types
|
||||
xla_layout.tiling(),
|
||||
xla_layout.element_size_in_bits())
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f'DeviceLocalLayout(major_to_minor={self.major_to_minor},'
|
||||
f' _tiling={self._tiling},'
|
||||
f' _sub_byte_element_size_in_bits={self._sub_byte_element_size_in_bits})'
|
||||
)
|
||||
def __repr__(self):
|
||||
return (
|
||||
f'DeviceLocalLayout(major_to_minor={self.major_to_minor},'
|
||||
f' _tiling={self._tiling},'
|
||||
f' _sub_byte_element_size_in_bits={self._sub_byte_element_size_in_bits})'
|
||||
)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.major_to_minor, self._tiling,
|
||||
self._sub_byte_element_size_in_bits))
|
||||
def __hash__(self):
|
||||
return hash((self.major_to_minor, self._tiling,
|
||||
self._sub_byte_element_size_in_bits))
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, DeviceLocalLayout):
|
||||
return False
|
||||
return (self.major_to_minor == other.major_to_minor and
|
||||
self._tiling == other._tiling and
|
||||
self._sub_byte_element_size_in_bits == other._sub_byte_element_size_in_bits)
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, DeviceLocalLayout):
|
||||
return False
|
||||
return (self.major_to_minor == other.major_to_minor and
|
||||
self._tiling == other._tiling and
|
||||
self._sub_byte_element_size_in_bits == other._sub_byte_element_size_in_bits)
|
||||
|
||||
def _to_xla_layout(self, dtype) -> str:
|
||||
if self._tiling is None:
|
||||
xla_layout = xc.Layout(self.major_to_minor[::-1])
|
||||
def _to_xla_layout(self, dtype) -> str:
|
||||
if self._tiling is None:
|
||||
xla_layout = xc.Layout(self.major_to_minor[::-1])
|
||||
else:
|
||||
if self._sub_byte_element_size_in_bits != 0:
|
||||
sub_byte_size = self._sub_byte_element_size_in_bits
|
||||
elif issubdtype(dtype, np.integer):
|
||||
sub_byte_size = iinfo(dtype).bits if iinfo(dtype).bits < 8 else 0
|
||||
else:
|
||||
if self._sub_byte_element_size_in_bits != 0:
|
||||
sub_byte_size = self._sub_byte_element_size_in_bits
|
||||
elif issubdtype(dtype, np.integer):
|
||||
sub_byte_size = iinfo(dtype).bits if iinfo(dtype).bits < 8 else 0
|
||||
else:
|
||||
sub_byte_size = 0
|
||||
xla_layout = xc.Layout(self.major_to_minor[::-1], self._tiling, # type: ignore
|
||||
sub_byte_size)
|
||||
return str(xla_layout)
|
||||
sub_byte_size = 0
|
||||
xla_layout = xc.Layout(self.major_to_minor[::-1], self._tiling, # type: ignore
|
||||
sub_byte_size)
|
||||
return str(xla_layout)
|
||||
|
||||
def check_compatible_aval(self, aval_shape: Shape):
|
||||
if len(self.major_to_minor) != len(aval_shape):
|
||||
raise ValueError(
|
||||
f'Length of major_to_minor and the rank of the value should match.'
|
||||
f' Got major_to_minor={self.major_to_minor} and shape={aval_shape}')
|
||||
|
||||
else:
|
||||
class DeviceLocalLayout: # type: ignore
|
||||
layout: xc.PjRtLayout
|
||||
|
||||
AUTO = AutoLayout()
|
||||
|
||||
def __init__(self, layout: xc.PjRtLayout):
|
||||
self._layout = layout
|
||||
self._layout_str = str(self._layout)
|
||||
|
||||
@staticmethod
|
||||
def from_pjrt_layout(pjrt_layout: xc.PjRtLayout):
|
||||
return DeviceLocalLayout(pjrt_layout) # type: ignore
|
||||
|
||||
def __repr__(self):
|
||||
return f'DeviceLocalLayout({self._layout_str})'
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self._layout)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, DeviceLocalLayout):
|
||||
return False
|
||||
return self._layout == other._layout
|
||||
|
||||
def _to_xla_layout(self, dtype) -> str:
|
||||
return self._layout_str
|
||||
|
||||
def check_compatible_aval(self, aval_shape: Shape):
|
||||
pass
|
||||
def check_compatible_aval(self, aval_shape: Shape):
|
||||
if len(self.major_to_minor) != len(aval_shape):
|
||||
raise ValueError(
|
||||
f'Length of major_to_minor and the rank of the value should match.'
|
||||
f' Got major_to_minor={self.major_to_minor} and shape={aval_shape}')
|
||||
|
||||
|
||||
LayoutOptions = Union[DeviceLocalLayout, None, AutoLayout] # pytype: disable=invalid-annotation
|
||||
|
@ -63,7 +63,6 @@ from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import func as func_dialect
|
||||
from jax._src.lib import jax_jit
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src import sharding
|
||||
from jax._src.sharding_impls import (
|
||||
NamedSharding, GSPMDSharding,
|
||||
@ -714,7 +713,7 @@ def _infer_params(
|
||||
resource_env = None
|
||||
pjit_mesh = None
|
||||
|
||||
skip_cache = xla_extension_version < 273 or config.dynamic_shapes.value
|
||||
skip_cache = config.dynamic_shapes.value
|
||||
if not skip_cache:
|
||||
signature, dynargs = jax_jit.parse_arguments(
|
||||
args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums,
|
||||
|
@ -32,7 +32,6 @@ from jax._src import tree_util
|
||||
from jax._src import util
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.op_shardings import (
|
||||
are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated)
|
||||
from jax._src.partition_spec import PartitionSpec
|
||||
@ -1065,8 +1064,6 @@ def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()):
|
||||
_check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes)
|
||||
return parsed_pspec
|
||||
|
||||
if xla_extension_version < 279:
|
||||
preprocess_with_manual = preprocess
|
||||
|
||||
def prepare_axis_resources(axis_resources,
|
||||
arg_name,
|
||||
|
@ -133,7 +133,7 @@ def _get_cmdclass(pkg_source_path):
|
||||
|
||||
|
||||
__version__ = _get_version_string()
|
||||
_minimum_jaxlib_version = "0.4.30"
|
||||
_minimum_jaxlib_version = "0.4.31"
|
||||
|
||||
def _version_as_tuple(version_str):
|
||||
return tuple(int(i) for i in version_str.split(".") if i.isdigit())
|
||||
|
@ -62,7 +62,6 @@ from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.compilation_cache import is_persistent_cache_enabled
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src.lib import xla_extension_version
|
||||
import jax._src.util as jax_util
|
||||
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
|
||||
import jax.custom_batching
|
||||
@ -2641,7 +2640,6 @@ class APITest(jtu.JaxTestCase):
|
||||
|
||||
self.assertEqual(count[0], 1)
|
||||
|
||||
@unittest.skipIf(xla_extension_version <= 273, "requires jaxlib 0.4.31")
|
||||
def test_jit_infer_params_cache(self):
|
||||
def f(x):
|
||||
return x
|
||||
@ -4427,7 +4425,6 @@ class APITest(jtu.JaxTestCase):
|
||||
g = jax.grad(f, argnums=-1)
|
||||
g(x, y) # doesn't crash
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 272, "requires jaxlib 0.4.31")
|
||||
def test_jit_negative_static_argnums(self):
|
||||
@partial(jax.jit, static_argnums=-1)
|
||||
def g(x, y):
|
||||
|
@ -15,7 +15,6 @@
|
||||
from absl.testing import absltest
|
||||
import jax
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_extension_version
|
||||
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
@ -27,9 +26,6 @@ class DeviceTest(jtu.JaxTestCase):
|
||||
|
||||
# TODO(pobudzey): Add a test for rocm devices when available.
|
||||
if jtu.is_device_cuda():
|
||||
if xla_extension_version < 276:
|
||||
self.skipTest('requires jaxlib 0.4.31')
|
||||
|
||||
self.assertEqual(device.platform, 'gpu')
|
||||
self.assertEqual(repr(device), 'CudaDevice(id=0)')
|
||||
elif jtu.test_device_matches(['tpu']):
|
||||
|
@ -19,7 +19,6 @@ update these tests.
|
||||
import dataclasses
|
||||
from functools import partial
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
|
||||
from absl.testing import absltest, parameterized
|
||||
@ -64,7 +63,6 @@ from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import cuda_versions
|
||||
from jax._src.lib import version as jaxlib_version
|
||||
from jax._src.lib import xla_extension_version
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
@ -591,8 +589,6 @@ class CompatTest(bctu.CompatTestBase):
|
||||
self.run_one_test(func, data)
|
||||
|
||||
def test_cuda_threefry2x32(self):
|
||||
logging.info("test_cuda_threefry2x32: xla_extension_version: %s",
|
||||
xla_extension_version)
|
||||
def func(x):
|
||||
return jax.random.uniform(x, (2, 4), dtype=np.float32)
|
||||
|
||||
|
@ -15,7 +15,6 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import unittest
|
||||
from absl.testing import absltest, parameterized
|
||||
|
||||
import jax
|
||||
@ -30,7 +29,6 @@ from jax._src import linear_util
|
||||
from jax._src import prng
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.extend import ffi
|
||||
|
||||
@ -124,7 +122,6 @@ class FfiTest(jtu.JaxTestCase):
|
||||
dtype=(np.int32,),
|
||||
)
|
||||
@jtu.run_on_devices("gpu")
|
||||
@unittest.skipIf(xla_extension_version < 272, "requires jaxlib 0.4.31")
|
||||
def testFfiCall(self, shape, dtype):
|
||||
pivots_size = shape[-1]
|
||||
permutation_size = 2 * pivots_size
|
||||
@ -140,7 +137,6 @@ class FfiTest(jtu.JaxTestCase):
|
||||
vectorized=(False, True),
|
||||
)
|
||||
@jtu.run_on_devices("gpu")
|
||||
@unittest.skipIf(xla_extension_version < 272, "requires jaxlib 0.4.31")
|
||||
def testFfiCallBatching(self, shape, dtype, vectorized):
|
||||
shape = (10,) + shape
|
||||
pivots_size = shape[-1]
|
||||
|
@ -25,7 +25,6 @@ from jax._src import config
|
||||
from jax._src.layout import Layout, DeviceLocalLayout as DLL
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.util import safe_zip
|
||||
from jax._src.lib import xla_extension_version
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
@ -405,9 +404,6 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(out, inp.T)
|
||||
|
||||
def test_device_put_user_concrete_layout(self):
|
||||
if xla_extension_version < 274:
|
||||
self.skipTest('Requires xla_extension_version >= 274')
|
||||
|
||||
shape = (8, 128)
|
||||
np_inp = np.arange(math.prod(shape)).reshape(shape)
|
||||
dll = DLL(major_to_minor=(1, 0))
|
||||
@ -437,9 +433,6 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
custom_dll.major_to_minor)
|
||||
|
||||
def test_compatible_aval_error(self):
|
||||
if xla_extension_version < 274:
|
||||
self.skipTest('Requires xla_extension_version >= 274')
|
||||
|
||||
custom_dll = DLL(major_to_minor=(0, 1, 2))
|
||||
l = Layout(custom_dll, SingleDeviceSharding(jax.devices()[0]))
|
||||
inp = np.arange(8)
|
||||
@ -454,9 +447,6 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
f(inp)
|
||||
|
||||
def test_incompatible_aval_error_device_put(self):
|
||||
if xla_extension_version < 274:
|
||||
self.skipTest('Requires xla_extension_version >= 274')
|
||||
|
||||
custom_dll = DLL(major_to_minor=(0, 1, 2))
|
||||
l = Layout(custom_dll, SingleDeviceSharding(jax.devices()[0]))
|
||||
inp = np.arange(8)
|
||||
|
@ -16,7 +16,6 @@
|
||||
|
||||
from functools import partial
|
||||
import itertools
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import scipy
|
||||
@ -34,7 +33,6 @@ from jax._src import config
|
||||
from jax._src.lax import linalg as lax_linalg
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.numpy.util import promote_dtypes_inexact
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
@ -1625,7 +1623,6 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
(a, b),
|
||||
(a, b))
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 277, "Requires jaxlib > 0.4.30")
|
||||
def testTriangularSolveSingularBatched(self):
|
||||
x = jnp.array([[1, 1], [0, 0]], dtype=np.float32)
|
||||
y = jnp.array([[1], [1.]], dtype=np.float32)
|
||||
|
Loading…
x
Reference in New Issue
Block a user