
Bump minimum jaxlib version to v0.4.30.

This corresponds to xla_extension_version 271 and mlir_api_version 57.
Peter Hawkins 2024-06-18 11:31:09 -04:00
parent 7827af9407
commit 07d24e7dcc
20 changed files with 32 additions and 145 deletions
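For readers skimming the hunks below: jax gates features on the version of the bundled XLA extension, and raising the minimum jaxlib lets the stale guards be deleted wholesale. A minimal sketch of the pattern this commit removes (the branch bodies are illustrative names, not taken from the diff):

    from jax._src.lib import xla_extension_version  # import deleted throughout this commit

    if xla_extension_version >= 271:
      take_new_code_path()   # requires jaxlib >= 0.4.30; illustrative name
    else:
      take_old_code_path()   # fallback for older jaxlibs, now unreachable; illustrative name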

@@ -40,7 +40,6 @@ from jax._src.interpreters import xla
 from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension as xe
-from jax._src.lib import xla_extension_version
 from jax._src.sharding import Sharding
 from jax._src.sharding_impls import (
     PmapSharding, SingleDeviceSharding,
@@ -1120,14 +1119,8 @@ def _array_shard_arg(xs, shardings):
       results.append(
           shard_sharded_device_array_slow_path(x, devices, indices, sharding))
 
-  if xla_extension_version < 271:
-    copy_outs = [
-        xc.copy_array_to_devices_with_sharding(x, d, s)  # pytype: disable=module-attr
-        for x, d, s in safe_zip(batch_xs, batch_devs, batch_shardings)
-    ]
-  else:
-    copy_outs = xc.batched_copy_array_to_devices_with_sharding(
-        batch_xs, batch_devs, batch_shardings)
+  copy_outs = xc.batched_copy_array_to_devices_with_sharding(
+      batch_xs, batch_devs, batch_shardings)
   for i, copy_out in safe_zip(batch_indices, copy_outs):
     assert results[i] is None
     results[i] = copy_out
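One user-facing operation that can funnel into this batched copy is jax.device_put; a hedged sketch of the entry point (exactly when the fast path triggers depends on the arrays' commitment and shardings):

    import jax
    import jax.numpy as jnp
    from jax.sharding import SingleDeviceSharding

    x = jnp.ones((8, 8))
    y = jax.device_put(x, SingleDeviceSharding(jax.devices()[0]))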

@@ -62,7 +62,6 @@ from jax._src.interpreters import partial_eval as pe
 from jax._src.interpreters import mlir
 from jax._src.interpreters import xla
 from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout
-from jax._src.lib import xla_extension_version
 from jax._src.lib import xla_client as xc
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
@@ -3175,10 +3174,7 @@ class MeshExecutable(stages.XlaExecutable):
             self.unsafe_call.in_handler.input_indices)
       else:
         fastpath_data = None
-      if xla_extension_version > 267:
-        return outs, fastpath_data, False  # Do not remove cache entry
-      else:
-        return outs, fastpath_data
+      return outs, fastpath_data, False  # Do not remove cache entry
 
     return xc._xla.pjit(
         self.unsafe_call.name, None, aot_cache_miss, [], [], [],
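For context, a hedged sketch of the cache-miss callback contract the C++ pjit fastpath assumes once jaxlib >= 0.4.30 is guaranteed (the helper is hypothetical; only the return shape is taken from the hunk above):

    def aot_cache_miss(*args, **kwargs):
      outs, fastpath_data = run_and_collect_fastpath_data(args, kwargs)  # hypothetical helper
      # The third element tells the C++ caller whether to evict its cache entry;
      # the AOT path never rebuilds with FDO, so it is always False here.
      return outs, fastpath_data, False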

@@ -335,11 +335,6 @@ def _approx_top_k_lowering(ctx, operand, *, k,
         backend_config=backend_config,
         result_shapes=result_shapes)
   else:
-    if xc.mlir_api_version < 57:
-      raise NotImplementedError(
-          "approx_top_k with non-constant k requires jaxlib version 0.4.29 or "
-          "newer")
     k_value, = mlir.eval_dynamic_shape_as_vals(ctx, (k,))
     out = mlir.custom_call(
         "stablehlo.dynamic_approx_top_k",

@@ -50,7 +50,6 @@ from jax._src.lax.control_flow.common import (
     _abstractify, _avals_short, _check_tree_and_avals, _initial_style_jaxpr,
     _initial_style_jaxpr_attrs, _make_closed_jaxpr_attrs, _prune_zeros,
     _typecheck_param)
-from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.numpy.ufuncs import logaddexp
@@ -2370,27 +2369,12 @@ def _cumulative_reduction_primitive(name, reduce_fn, reduce_window_fn):
         mlir.cache_lowering(mlir.lower_fun(fn, multiple_results=False)),
         platform=platform)
-  if xla_extension_version >= 266:
-    # For jax-metal, until reduce_window legalization is better supported.
-    register_lowering(partial(associative_scan, reduce_fn), 'METAL')
-    # In XLA, there's a rewriter for an O(N^2) reduce-window implementation.
-    register_lowering(
-        partial(cumred_reduce_window_impl, reduce_window_fn)
-    )
-  else:
-    # Older XLA versions only have this rewrite for TPU.
-    register_lowering(
-        partial(cumred_reduce_window_impl, reduce_window_fn), 'tpu'
-    )
-    # Default for platforms not treated specially below.
-    register_lowering(partial(associative_scan, reduce_fn))
-    # On GPU, we choose between window reduction and associative scan
-    # based on the input size.
-    for platform in ['cuda', 'rocm']:
-      register_lowering(
-          partial(cumred_gpu_impl, reduce_window_fn, reduce_fn), platform
-      )
+  # For jax-metal, until reduce_window legalization is better supported.
+  register_lowering(partial(associative_scan, reduce_fn), 'METAL')
+  # In XLA, there's a rewriter for an O(N^2) reduce-window implementation.
+  register_lowering(
+      partial(cumred_reduce_window_impl, reduce_window_fn)
+  )
   return reducer_p
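For context, the registrations above back the cumulative reductions (lax.cumsum, lax.cumprod, lax.cummax, ...); a quick illustration of one of them:

    import jax.numpy as jnp
    from jax import lax

    x = jnp.arange(1, 5)   # [1, 2, 3, 4]
    y = lax.cumsum(x)      # [1, 3, 6, 10]; lowered via reduce_window or an associative scan, per platform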

@@ -44,7 +44,6 @@ from jax._src.lib import gpu_linalg
 from jax._src.lib import gpu_solver
 from jax._src.lib import gpu_sparse
 from jax._src.lib import lapack
-from jax._src.lib import version as jaxlib_version
 from jax._src.lib import xla_client
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import chlo
@@ -510,10 +509,6 @@ def _cholesky_update_cuda_lowering_rule(ctx, r_matrix, w_vector):
     raise NotImplementedError(
         "Can only lower fast cholesky_update on CUDA."
     )
-  if jaxlib_version < (0, 4, 29):
-    raise NotImplementedError(
-        f"The jaxlib version {jaxlib_version} is too old."
-        "Please update to at least 0.4.29.")
   return gpu_linalg.cuda_cholesky_update(
       r_matrix, w_vector, r_matrix_aval.dtype)
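As background for the hunk above: a rank-1 Cholesky update produces the factor of A + w wᵀ from the factor of A without refactorizing. A NumPy reference sketch (independent of the gated CUDA entry point):

    import numpy as np

    rng = np.random.default_rng(0)
    a = rng.normal(size=(4, 4))
    a = a @ a.T + 4 * np.eye(4)               # symmetric positive definite
    w = rng.normal(size=4)
    r = np.linalg.cholesky(a).T               # upper-triangular R with A = R^T R
    r_updated = np.linalg.cholesky(a + np.outer(w, w)).T  # what the kernel computes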

@@ -60,13 +60,12 @@ from jax._src.interpreters import ad
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
 from jax._src.interpreters import pxla
-from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import func as func_dialect
 from jax._src.lib import xla_client as xc
 from jax._src import sharding
 from jax._src.sharding_impls import (
-    NamedSharding, XLACompatibleSharding, GSPMDSharding,
+    NamedSharding, GSPMDSharding,
     SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue,
     ParsedPartitionSpec, SpecSync, get_single_pspec, is_auto, is_unspecified,
     is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding)
@@ -333,10 +332,7 @@ def _cpp_pjit(jit_info: PjitInfo):
         jaxpr.consts, jit_info.abstracted_axes,
         pgle_profiler)
-    if xla_extension_version > 267:
-      return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)
-    else:
-      return outs, maybe_fastpath_data
+    return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)
 
   fun = jit_info.fun
   cpp_pjit_f = xc._xla.pjit(
@@ -1403,10 +1399,6 @@ def _resolve_in_shardings(
       # not allow None as the sharding.
       if arg_s is None:
         continue
-      if xla_extension_version < 270:
-        if not isinstance(arg_s, XLACompatibleSharding):
-          raise ValueError(f'One of the argument to pjit got sharding {arg_s} '
-                           'which is not a subclass of XLACompatibleSharding.')
       # Don't consider PmapSharding inputs as committed. They will get resharded
       # unconditionally.
       if isinstance(arg_s, PmapSharding):
@@ -1621,10 +1613,7 @@ def _pjit_call_impl(*args, jaxpr,
     fastpath_data = _get_fastpath_data(
         compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects,
         jaxpr.consts, None, pgle_profiler)
-    if xla_extension_version > 267:
-      return out_flat, fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)
-    else:
-      return out_flat, fastpath_data
+    return out_flat, fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)
 
   f = _get_jaxpr_as_fun(
       jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,

@@ -32,7 +32,6 @@ from jax._src import util
 from jax._src import xla_bridge
 from jax._src import core
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src.op_shardings import (
     are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated)
 from jax._src.partition_spec import PartitionSpec
@@ -1054,10 +1053,6 @@ def prepare_axis_resources(axis_resources,
       if isinstance(entry, PmapSharding):
         raise ValueError(f'One of {what} got sharding {entry} which is not '
                          'allowed.')
-      if xla_extension_version < 270:
-        if not isinstance(entry, XLACompatibleSharding):
-          raise ValueError(f'One of {what} got sharding {entry} which is not a '
-                           'subclass of XLACompatibleSharding.')
       new_entries.append(entry)
     else:
       new_entries.append(ParsedPartitionSpec.from_user_input(

@@ -50,7 +50,6 @@ from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm
 from jax._src.interpreters import mlir
 from jax._src.interpreters import pxla
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact
 from jax._src.public_test_util import (  # noqa: F401
     _assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads,
@@ -245,32 +244,18 @@ def count_primitive_compiles():
 @contextmanager
 def count_device_put_fast_path_hit():
-  if xla_extension_version < 271:
-    original_fn = xc.copy_array_to_devices_with_sharding
-    count = [0]
+  original_fn = xc.batched_copy_array_to_devices_with_sharding
+  count = [0]
 
-    def copy_array_to_devices_with_sharding_and_count(*args, **kwargs):
-      count[0] += 1
-      return original_fn(*args, **kwargs)
+  def batched_copy_array_to_devices_with_sharding_and_count(*args, **kwargs):
+    count[0] += 1
+    return original_fn(*args, **kwargs)
 
-    xc.copy_array_to_devices_with_sharding = copy_array_to_devices_with_sharding_and_count
-    try:
-      yield count
-    finally:
-      xc.copy_array_to_devices_with_sharding = original_fn
-  else:
-    original_fn = xc.batched_copy_array_to_devices_with_sharding
-    count = [0]
-    def batched_copy_array_to_devices_with_sharding_and_count(*args, **kwargs):
-      count[0] += 1
-      return original_fn(*args, **kwargs)
-    xc.batched_copy_array_to_devices_with_sharding = batched_copy_array_to_devices_with_sharding_and_count
-    try:
-      yield count
-    finally:
-      xc.batched_copy_array_to_devices_with_sharding = original_fn
+  xc.batched_copy_array_to_devices_with_sharding = batched_copy_array_to_devices_with_sharding_and_count
+  try:
+    yield count
+  finally:
+    xc.batched_copy_array_to_devices_with_sharding = original_fn
 
 
 @contextmanager
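A hedged usage sketch of the counter above in a test (the input setup is illustrative; whether a particular device_put takes the batched fast path depends on the array's commitment and sharding):

    import jax
    import jax.numpy as jnp
    from jax._src import test_util as jtu

    x = jnp.arange(8)
    s = jax.sharding.SingleDeviceSharding(jax.devices()[0])
    with jtu.count_device_put_fast_path_hit() as count:
      jax.device_put(x, s)
    # count[0] is the number of batched fast-path copies observed.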

@@ -45,7 +45,6 @@ from jax._src.cloud_tpu_init import get_tpu_library_path
 from jax._src.lib import cuda_versions
 from jax._src.lib import xla_client
 from jax._src.lib import xla_extension
-from jax._src.lib import xla_extension_version
 
 logger = logging.getLogger(__name__)
@@ -147,13 +146,9 @@ def tpu_client_timer_callback(timer_secs: float) -> xla_client.Client | None:
   t.start()
 
   try:
-    if xla_extension_version >= 267:
-      client = xla_client.make_tpu_client(  # type: ignore
-          get_tpu_library_path(),
-          _options_from_jax_configs("tpu"))
-    else:
-      client = xla_client.make_tpu_client(
-          get_tpu_library_path())
+    client = xla_client.make_tpu_client(  # type: ignore
+        get_tpu_library_path(),
+        _options_from_jax_configs("tpu"))
   finally:
     t.cancel()

@@ -133,7 +133,7 @@ def _get_cmdclass(pkg_source_path):
 __version__ = _get_version_string()
 
-_minimum_jaxlib_version = "0.4.27"
+_minimum_jaxlib_version = "0.4.30"
 
 def _version_as_tuple(version_str):
   return tuple(int(i) for i in version_str.split(".") if i.isdigit())
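A quick check of the helper above (values illustrative): non-numeric components such as dev suffixes are dropped, so comparisons against the minimum stay well defined.

    assert _version_as_tuple("0.4.30") == (0, 4, 30)
    assert _version_as_tuple("0.4.30.dev20240618") == (0, 4, 30)  # suffix dropped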

@@ -58,7 +58,6 @@ from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
 from jax._src.lib import xla_client
 from jax._src.lib import xla_extension
-from jax._src.lib import xla_extension_version
 import jax._src.util as jax_util
 from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
 import jax.custom_batching
@@ -707,7 +706,6 @@ class JitTest(jtu.BufferDonationTestCase):
     self.assertNotEqual(z3.unsafe_buffer_pointer(), x1.unsafe_buffer_pointer())
     self.assertEqual(z2, 1)
 
-  @unittest.skipIf(xla_extension_version < 264, "jaxlib version too old")
   def test_print_token_buffer_error(self):
     token = jax.lax.create_token()
     with self.assertRaisesRegex(

@@ -30,7 +30,6 @@ from jax._src import op_shardings
 from jax._src import test_util as jtu
 from jax._src import xla_bridge as xb
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src.util import safe_zip
 from jax._src.sharding import common_devices_indices_map
 from jax._src.sharding_impls import (_op_sharding_to_pos_sharding,
@@ -1244,8 +1243,6 @@ class ShardingTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(ValueError, msg):
       jax.jit(f)(x)
 
-  @unittest.skipIf(xla_extension_version < 269,
-                   "Test requires jaxlib 0.4.29 or newer")
   def test_make_array_from_single_device_arrays_bad_inputs(self):
     x = jnp.arange(10)
     mesh = jtu.create_global_mesh((2,), ('x',))

@@ -19,16 +19,11 @@ from jax import numpy as jnp
 from jax._src import config
 from jax._src import test_util as jtu
 from jax._src.lax import linalg as lax_linalg
-from jax._src.lib import version as jaxlib_version  # pylint: disable=g-importing-member
 import numpy as np
 
 config.parse_flags_with_absl()
 
 
 class CholeskyUpdateTest(jtu.JaxTestCase):
 
-  def setUp(self):
-    super().setUp()
-    if jaxlib_version < (0, 4, 29):
-      self.skipTest("Requires jaxlib 0.4.29 or newer")
-
   @jtu.sample_product(
       shape=[

@@ -62,7 +62,6 @@ from jax.sharding import PartitionSpec as P
 from jax._src import config
 from jax._src import test_util as jtu
 from jax._src.lib import cuda_versions
-from jax._src.lib import xla_client
 
 config.parse_flags_with_absl()
@@ -713,8 +712,6 @@ class CompatTest(bctu.CompatTestBase):
                    check_results=check_top_k_results)
 
   def test_dynamic_approx_top_k(self):
-    if xla_client.mlir_api_version < 57:
-      self.skipTest("Requires newer jaxlib")
     # stablehlo.dynamic_approx_top_k is used temporarily for a approx_top_k
     # with dynamism
     # This is the input that was used to generate the test_data

@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import os
-import unittest
 
 import numpy as np
 from absl.testing import absltest, parameterized
@@ -31,7 +30,6 @@ from jax._src import test_util as jtu
 from jax._src.interpreters import mlir
 from jax._src.lib.mlir import ir
 from jax._src.extend import ffi
-from jax._src.lib import xla_extension_version
 
 jax.config.parse_flags_with_absl()
@@ -93,7 +91,6 @@ class RandomTest(jtu.JaxTestCase):
 
 class FfiTest(jtu.JaxTestCase):
 
-  @unittest.skipIf(xla_extension_version < 265, "Requires jaxlib 0.4.29")
   def testHeadersExist(self):
     base_dir = os.path.join(jex.ffi.include_dir(), "xla", "ffi", "api")
     for header in ["c_api.h", "api.h", "ffi.h"]:

@@ -26,7 +26,6 @@ from jax import lax
 from jax._src import test_util as jtu
 from jax._src import xla_bridge as xb
 from jax._src import config
-from jax._src.lib import xla_extension_version
 from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
 import jax.numpy as jnp
 from jax.ad_checkpoint import Offloadable, remat, Recompute
@@ -68,8 +67,6 @@ def _create_inputs(shape, pspec, mem_kind=None):
 class ShardingMemoriesTest(jtu.JaxTestCase):
 
   def setUp(self):
-    if xla_extension_version < 265 and not jtu.test_device_matches(["tpu"]):
-      self.skipTest("Memories do not work on CPU and GPU backends yet.")
     # TODO(b/311021572)
     if jtu.is_cloud_tpu():
       self.skipTest("Experimental feature not yet implemented on Cloud TPU")

@@ -32,7 +32,6 @@ from jax.experimental import profiler as exp_profiler
 import jax.numpy as jnp
 from jax.sharding import NamedSharding, PartitionSpec
 from jax._src import compilation_cache as cc
-from jax._src.lib import xla_extension_version
 import numpy as np
 
 from jax.experimental.serialize_executable import (
@@ -55,9 +54,6 @@ class PgleTest(jtu.JaxTestCase):
   @unittest.skip("Test failing in CI")
   def testPGLEProfilerGetFDOProfile(self):
-    if xla_extension_version < 268:
-      return self.skipTest('Requires xla_extension_version >= 268')
-
     mesh = jtu.create_global_mesh((2,), ('x',))
 
     @partial(
@@ -86,9 +82,6 @@ class PgleTest(jtu.JaxTestCase):
     self.assertIn(b'custom', fdo_profile)
 
   def testAutoPgle(self):
-    if xla_extension_version < 268:
-      return self.skipTest('Requires xla_extension_version >= 268')
-
     mesh = jtu.create_global_mesh((2,), ('x',))
 
     @partial(
@@ -126,9 +119,6 @@ class PgleTest(jtu.JaxTestCase):
     self.assertEqual(cache_miss_count[0], 0)
 
   def testAutoPgleWithAot(self):
-    if xla_extension_version < 268:
-      return self.skipTest('Requires xla_extension_version >= 268')
-
     @jax.jit
     def f(x):
       return x * 2
@@ -153,8 +143,6 @@ class PgleTest(jtu.JaxTestCase):
   @unittest.skip("Test failing in CI")
   def testAutoPgleWithPersistentCache(self):
-    if xla_extension_version < 268:
-      return self.skipTest('Requires xla_extension_version >= 268')
-
     @jax.jit
     def f(x):

@@ -2014,8 +2014,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
         expect_error=(
             (NotImplementedError, "aggregate_to_topk=False") if (
                 not agg and (isinstance(k, str) or
-                             isinstance(n, str))) else
-            (NotImplementedError, "requires jaxlib version") if xla_client.mlir_api_version < 57 and agg and isinstance(k, str) else
+                             isinstance(n, str))) else
             None
         ))
     for n in [8, "n"]

@@ -42,7 +42,6 @@ from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
 from jax._src import linear_util as lu
 from jax._src import tree_util
-from jax._src.lib import xla_extension_version
 import jax.numpy as jnp
 
 from jax.experimental.custom_partitioning import custom_partitioning
@@ -1706,7 +1705,6 @@ class ShardMapTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(ValueError, "in_specs refers to 'j'"):
       f(v)
 
-  @unittest.skipIf(xla_extension_version < 262, "Requires jaxlib 0.4.28")
   def test_nested_partial_auto(self):
     mesh = jtu.create_global_mesh((2, 2), ('i', 'j'))

@@ -27,7 +27,6 @@ from jax._src import test_util as jtu
 from jax._src import xla_bridge as xb
 from jax._src.interpreters import xla
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 
 config.parse_flags_with_absl()
@@ -162,14 +161,9 @@ class XlaBridgeTest(jtu.JaxTestCase):
     def _mock_tpu_client(library_path=None):
       _mock_tpu_client_with_options(library_path=library_path, options=None)
 
-    if xla_extension_version >= 267:
-      with mock.patch.object(xc, "make_tpu_client",
-                             side_effect=_mock_tpu_client_with_options):
-        xb.tpu_client_timer_callback(0.01)
-    else:
-      with mock.patch.object(xc, "make_tpu_client",
-                             side_effect=_mock_tpu_client):
-        xb.tpu_client_timer_callback(0.01)
+    with mock.patch.object(xc, "make_tpu_client",
+                           side_effect=_mock_tpu_client_with_options):
+      xb.tpu_client_timer_callback(0.01)
 
   def test_register_plugin(self):
     with self.assertLogs(level="WARNING") as log_output:
@@ -205,7 +199,7 @@ class XlaBridgeTest(jtu.JaxTestCase):
       self.assertTrue(registration.experimental)
 
       options = {}
-      if xb.get_backend().platform == 'tpu' and xla_extension_version >= 267:
+      if xb.get_backend().platform == 'tpu':
         options["ml_framework_name"] = "JAX"
         options["ml_framework_version"] = version.__version__
       mock_make.assert_called_once_with("name1", options, None)
@@ -243,7 +237,7 @@ class XlaBridgeTest(jtu.JaxTestCase):
         "string_option": "string",
         "float_option": 1.0,
     }
-    if xb.get_backend().platform == 'tpu' and xla_extension_version >= 267:
+    if xb.get_backend().platform == 'tpu':
       options["ml_framework_name"] = "JAX"
       options["ml_framework_version"] = version.__version__