Mirror of https://github.com/ROCm/jax.git (synced 2025-04-26 12:56:08 +00:00)
Bump minimum jaxlib version to 0.4.31. The corresponding xla_extension_version is 279 and mlir_api_version is 57
PiperOrigin-RevId: 657400413
parent 2106a25977
commit 30037547d7
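Reviewer note: nearly every hunk below is the same mechanical cleanup. With the minimum jaxlib at 0.4.31, xla_extension_version is guaranteed to be at least 279, so guards against older extension versions are always (or never) taken, and the guards, their fallback branches, and the now-unused imports can all be deleted. A minimal sketch of the retired pattern follows; `feature`, `_new_path`, and `_old_fallback` are hypothetical placeholders, not names from the JAX codebase:

# Before the bump: behavior gated on the installed jaxlib's extension version.
from jax._src.lib import xla_extension_version  # this import is removed file by file below

def _new_path(x):      # placeholder for the code path that needs a newer jaxlib
  return x

def _old_fallback(x):  # placeholder for the compatibility fallback
  return x

def feature(x):
  if xla_extension_version >= 274:  # always true once jaxlib >= 0.4.31
    return _new_path(x)
  else:                             # dead branch after the bump; deleted in the hunks below
    return _old_fallback(x)

After the bump, `feature` reduces to a direct call to the new path, which is exactly the shape of the changes to is_user_xla_layout_equal, _infer_params, and the test skip decorators below.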
@@ -207,7 +207,7 @@ as in the following example:
 >>> print(jax.jit(new_prim.bind).lower(1.).compiler_ir())
 module @jit_bind attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
   func.func public @main(%arg0: tensor<f32> {mhlo.layout_mode = "default"}) -> (tensor<f32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
-    %0 = stablehlo.custom_call @my_new_prim(%arg0) {api_version = 2 : i32} : (tensor<f32>) -> tensor<f32>
+    %0 = stablehlo.custom_call @my_new_prim(%arg0) {api_version = 2 : i32, backend_config = ""} : (tensor<f32>) -> tensor<f32>
     return %0 : tensor<f32>
   }
 }
@@ -60,7 +60,6 @@ from jax._src.interpreters import mlir
 from jax._src.interpreters import xla
 from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.partition_spec import PartitionSpec
@@ -2492,13 +2491,10 @@ def maybe_recover_user_shardings(

 def is_user_xla_layout_equal(ul: DeviceLocalLayout | AutoLayout,
                              xl: DeviceLocalLayout) -> bool:
-  if xla_extension_version >= 274:
   if isinstance(ul, DeviceLocalLayout) and ul._tiling is None:
     return ul.major_to_minor == xl.major_to_minor
   else:
     return ul == xl
-  else:
-    return ul == xl

 def _check_user_xla_layout(ul, xl, what: str):
   if not is_user_xla_layout_equal(ul, xl):
@@ -21,7 +21,6 @@ from jax._src.dtypes import iinfo, issubdtype
 from jax._src.sharding import Sharding
 from jax._src.sharding_impls import AUTO as AutoSharding, is_auto
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version

 Shape = tuple[int, ...]

@@ -31,7 +30,6 @@ class AutoLayout:
     return "AUTO"


-if xla_extension_version >= 274:
 class DeviceLocalLayout:
   major_to_minor: tuple[int, ...]
   _tiling: tuple[tuple[int, ...], ...] | None
@@ -91,37 +89,6 @@ if xla_extension_version >= 274:
           f'Length of major_to_minor and the rank of the value should match.'
           f' Got major_to_minor={self.major_to_minor} and shape={aval_shape}')

-else:
-  class DeviceLocalLayout:  # type: ignore
-    layout: xc.PjRtLayout
-
-    AUTO = AutoLayout()
-
-    def __init__(self, layout: xc.PjRtLayout):
-      self._layout = layout
-      self._layout_str = str(self._layout)
-
-    @staticmethod
-    def from_pjrt_layout(pjrt_layout: xc.PjRtLayout):
-      return DeviceLocalLayout(pjrt_layout)  # type: ignore
-
-    def __repr__(self):
-      return f'DeviceLocalLayout({self._layout_str})'
-
-    def __hash__(self):
-      return hash(self._layout)
-
-    def __eq__(self, other):
-      if not isinstance(other, DeviceLocalLayout):
-        return False
-      return self._layout == other._layout
-
-    def _to_xla_layout(self, dtype) -> str:
-      return self._layout_str
-
-    def check_compatible_aval(self, aval_shape: Shape):
-      pass
-

 LayoutOptions = Union[DeviceLocalLayout, None, AutoLayout]  # pytype: disable=invalid-annotation
 ShardingOptions = Union[Sharding, None, AutoSharding]
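Reviewer note: with the pre-274 fallback class removed, DeviceLocalLayout always exposes the major_to_minor/_tiling form kept above. A small usage sketch assembled from the layout_test.py hunks later in this diff; the shape, values, and single-device sharding are illustrative, and the private jax._src.layout import mirrors the tests rather than a public API:

import math
import numpy as np
import jax
from jax._src.layout import Layout, DeviceLocalLayout as DLL
from jax.sharding import SingleDeviceSharding

shape = (8, 128)
np_inp = np.arange(math.prod(shape)).reshape(shape)
dll = DLL(major_to_minor=(1, 0))                         # concrete layout given by dimension order only
l = Layout(dll, SingleDeviceSharding(jax.devices()[0]))
out = jax.device_put(np_inp, l)                          # as exercised by test_device_put_user_concrete_layout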
@@ -63,7 +63,6 @@ from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import func as func_dialect
 from jax._src.lib import jax_jit
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src import sharding
 from jax._src.sharding_impls import (
     NamedSharding, GSPMDSharding,
@@ -714,7 +713,7 @@ def _infer_params(
     resource_env = None
     pjit_mesh = None

-  skip_cache = xla_extension_version < 273 or config.dynamic_shapes.value
+  skip_cache = config.dynamic_shapes.value
   if not skip_cache:
     signature, dynargs = jax_jit.parse_arguments(
         args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums,
@@ -32,7 +32,6 @@ from jax._src import tree_util
 from jax._src import util
 from jax._src import xla_bridge
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src.op_shardings import (
     are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated)
 from jax._src.partition_spec import PartitionSpec
@@ -1065,8 +1064,6 @@ def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()):
   _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes)
   return parsed_pspec

-if xla_extension_version < 279:
-  preprocess_with_manual = preprocess

 def prepare_axis_resources(axis_resources,
                            arg_name,
@@ -133,7 +133,7 @@ def _get_cmdclass(pkg_source_path):


 __version__ = _get_version_string()
-_minimum_jaxlib_version = "0.4.30"
+_minimum_jaxlib_version = "0.4.31"

 def _version_as_tuple(version_str):
   return tuple(int(i) for i in version_str.split(".") if i.isdigit())
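Reviewer note: `_version_as_tuple` appears in full in the hunk above; a quick illustrative check of how the new minimum parses and compares:

def _version_as_tuple(version_str):
  return tuple(int(i) for i in version_str.split(".") if i.isdigit())

assert _version_as_tuple("0.4.31") == (0, 4, 31)
assert _version_as_tuple("0.4.31") > _version_as_tuple("0.4.30")  # tuple comparison orders versions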
@@ -62,7 +62,6 @@ from jax._src.interpreters import partial_eval as pe
 from jax._src.compilation_cache import is_persistent_cache_enabled
 from jax._src.lib import xla_client
 from jax._src.lib import xla_extension
-from jax._src.lib import xla_extension_version
 import jax._src.util as jax_util
 from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
 import jax.custom_batching
@@ -2641,7 +2640,6 @@ class APITest(jtu.JaxTestCase):

     self.assertEqual(count[0], 1)

-  @unittest.skipIf(xla_extension_version <= 273, "requires jaxlib 0.4.31")
   def test_jit_infer_params_cache(self):
     def f(x):
       return x
@@ -4427,7 +4425,6 @@ class APITest(jtu.JaxTestCase):
     g = jax.grad(f, argnums=-1)
     g(x, y)  # doesn't crash

-  @unittest.skipIf(xla_extension_version < 272, "requires jaxlib 0.4.31")
   def test_jit_negative_static_argnums(self):
     @partial(jax.jit, static_argnums=-1)
     def g(x, y):
@@ -15,7 +15,6 @@
 from absl.testing import absltest
 import jax
 from jax._src import test_util as jtu
-from jax._src.lib import xla_extension_version

 jax.config.parse_flags_with_absl()

@@ -27,9 +26,6 @@ class DeviceTest(jtu.JaxTestCase):

     # TODO(pobudzey): Add a test for rocm devices when available.
     if jtu.is_device_cuda():
-      if xla_extension_version < 276:
-        self.skipTest('requires jaxlib 0.4.31')
-
       self.assertEqual(device.platform, 'gpu')
       self.assertEqual(repr(device), 'CudaDevice(id=0)')
     elif jtu.test_device_matches(['tpu']):
@@ -19,7 +19,6 @@ update these tests.
 import dataclasses
 from functools import partial
 import itertools
-import logging
 import math

 from absl.testing import absltest, parameterized
@@ -64,7 +63,6 @@ from jax._src import config
 from jax._src import test_util as jtu
 from jax._src.lib import cuda_versions
 from jax._src.lib import version as jaxlib_version
-from jax._src.lib import xla_extension_version

 config.parse_flags_with_absl()

@@ -591,8 +589,6 @@ class CompatTest(bctu.CompatTestBase):
     self.run_one_test(func, data)

   def test_cuda_threefry2x32(self):
-    logging.info("test_cuda_threefry2x32: xla_extension_version: %s",
-                 xla_extension_version)
     def func(x):
       return jax.random.uniform(x, (2, 4), dtype=np.float32)

@@ -15,7 +15,6 @@
 import os

 import numpy as np
-import unittest
 from absl.testing import absltest, parameterized

 import jax
@@ -30,7 +29,6 @@ from jax._src import linear_util
 from jax._src import prng
 from jax._src import test_util as jtu
 from jax._src.interpreters import mlir
-from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir import ir
 from jax._src.extend import ffi

@@ -124,7 +122,6 @@ class FfiTest(jtu.JaxTestCase):
       dtype=(np.int32,),
   )
   @jtu.run_on_devices("gpu")
-  @unittest.skipIf(xla_extension_version < 272, "requires jaxlib 0.4.31")
   def testFfiCall(self, shape, dtype):
     pivots_size = shape[-1]
     permutation_size = 2 * pivots_size
@@ -140,7 +137,6 @@ class FfiTest(jtu.JaxTestCase):
       vectorized=(False, True),
   )
   @jtu.run_on_devices("gpu")
-  @unittest.skipIf(xla_extension_version < 272, "requires jaxlib 0.4.31")
   def testFfiCallBatching(self, shape, dtype, vectorized):
     shape = (10,) + shape
     pivots_size = shape[-1]
@@ -25,7 +25,6 @@ from jax._src import config
 from jax._src.layout import Layout, DeviceLocalLayout as DLL
 from jax._src import test_util as jtu
 from jax._src.util import safe_zip
-from jax._src.lib import xla_extension_version

 config.parse_flags_with_absl()

@@ -405,9 +404,6 @@ class LayoutTest(jtu.JaxTestCase):
     self.assertArraysEqual(out, inp.T)

   def test_device_put_user_concrete_layout(self):
-    if xla_extension_version < 274:
-      self.skipTest('Requires xla_extension_version >= 274')
-
     shape = (8, 128)
     np_inp = np.arange(math.prod(shape)).reshape(shape)
     dll = DLL(major_to_minor=(1, 0))
@@ -437,9 +433,6 @@ class LayoutTest(jtu.JaxTestCase):
                      custom_dll.major_to_minor)

   def test_compatible_aval_error(self):
-    if xla_extension_version < 274:
-      self.skipTest('Requires xla_extension_version >= 274')
-
     custom_dll = DLL(major_to_minor=(0, 1, 2))
     l = Layout(custom_dll, SingleDeviceSharding(jax.devices()[0]))
     inp = np.arange(8)
@@ -454,9 +447,6 @@ class LayoutTest(jtu.JaxTestCase):
       f(inp)

   def test_incompatible_aval_error_device_put(self):
-    if xla_extension_version < 274:
-      self.skipTest('Requires xla_extension_version >= 274')
-
     custom_dll = DLL(major_to_minor=(0, 1, 2))
     l = Layout(custom_dll, SingleDeviceSharding(jax.devices()[0]))
     inp = np.arange(8)
@@ -16,7 +16,6 @@

 from functools import partial
 import itertools
-import unittest

 import numpy as np
 import scipy
@@ -34,7 +33,6 @@ from jax._src import config
 from jax._src.lax import linalg as lax_linalg
 from jax._src import test_util as jtu
 from jax._src import xla_bridge
-from jax._src.lib import xla_extension_version
 from jax._src.numpy.util import promote_dtypes_inexact

 config.parse_flags_with_absl()
@@ -1625,7 +1623,6 @@ class ScipyLinalgTest(jtu.JaxTestCase):
         (a, b),
         (a, b))

-  @unittest.skipIf(xla_extension_version < 277, "Requires jaxlib > 0.4.30")
   def testTriangularSolveSingularBatched(self):
     x = jnp.array([[1, 1], [0, 0]], dtype=np.float32)
     y = jnp.array([[1], [1.]], dtype=np.float32)