diff --git a/CHANGELOG.md b/CHANGELOG.md index dfdd9e30c..d95cac920 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ Remember to align the itemized text with the first line of an item within a list ## jax 0.4.8 +* Changes + * The minimum jaxlib version has been bumped from 0.4.6 to 0.4.7. + * Deprecations * CUDA 11.4 support has been dropped. JAX GPU wheels only support CUDA 11.8 and CUDA 12. Older CUDA versions may work if jaxlib is built diff --git a/jax/_src/api.py b/jax/_src/api.py index 8dadd1422..60ea338b6 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -64,7 +64,6 @@ from jax._src.lax import lax as lax_internal from jax._src.lib import jax_jit from jax._src.lib import xla_client as xc from jax._src.lib import pmap_lib -from jax._src.lib import xla_extension_version from jax._src.sharding_impls import PmapSharding from jax._src.traceback_util import api_boundary from jax._src.tree_util import broadcast_prefix, generate_key_paths @@ -2605,8 +2604,6 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): # raise ValueError("the shards passed to device_put_sharded must have " f"consistent shape and dtype, but got {a1} and {a2}.") stacked_aval = avals[0].update(shape=(len(devices),) + avals[0].shape) - if xla_extension_version < 139: - xs = [xla.canonicalize_dtype(arg) for arg in xs] sharding_spec = pxla._create_pmap_sharding_spec(stacked_aval) return pxla.batched_device_put( stacked_aval, PmapSharding(np.array(devices), sharding_spec), @@ -2656,8 +2653,6 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811 assert (isinstance(aval, ShapedArray) and len(xla.aval_to_xla_shapes(aval)) == 1) sharding_spec = pxla._create_pmap_sharding_spec(aval) - if xla_extension_version < 139: - x = xla.canonicalize_dtype(x) buf = jax.device_put(x, devices[0]) return pxla.batched_device_put( aval, PmapSharding(np.array(devices), sharding_spec), diff --git a/jax/_src/array.py b/jax/_src/array.py index 2e92e768a..a995d4f11 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -32,7 +32,6 @@ from jax._src import profiler from jax._src.config import config from jax._src.util import use_cpp_class, use_cpp_method from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src import api from jax._src.typing import ArrayLike from jax.interpreters import mlir @@ -341,14 +340,7 @@ class ArrayImpl(basearray.Array): return _single_device_array_from_buf(arr, committed=False) return lax_numpy._rewriting_take(self, idx) else: - if xla_extension_version >= 144: - return lax_numpy._rewriting_take(self, idx) - else: - if (dispatch.is_single_device_sharding(self.sharding) or - self.is_fully_replicated or _is_reduced_on_dim(idx)): - return lax_numpy._rewriting_take(self, idx) - else: - return api.device_put(self._value[idx]) + return lax_numpy._rewriting_take(self, idx) def __iter__(self): if self.ndim == 0: @@ -404,7 +396,7 @@ class ArrayImpl(basearray.Array): 'named_shape': self.aval.named_shape} return (_reconstruct_array, (fun, args, arr_state, aval_state)) - @use_cpp_method(xla_extension_version >= 138) + @use_cpp_method() def unsafe_buffer_pointer(self): if len(self._arrays) != 1: raise ValueError("unsafe_buffer_pointer() is supported only for unsharded" @@ -412,14 +404,14 @@ class ArrayImpl(basearray.Array): return self._arrays[0].unsafe_buffer_pointer() @property - @use_cpp_method(xla_extension_version >= 138) + @use_cpp_method() def __cuda_array_interface__(self): if len(self._arrays) != 1: raise ValueError("__cuda_array_interface__() is supported only for " "unsharded arrays.") return self._arrays[0].__cuda_array_interface__ # pytype: disable=attribute-error # bind-properties - @use_cpp_method(xla_extension_version >= 138) + @use_cpp_method() def on_device_size_in_bytes(self): """Returns the total global on-device size of the array in bytes.""" arr = self._arrays[0] @@ -495,7 +487,7 @@ class ArrayImpl(basearray.Array): out.append(Shard(global_d, self.sharding, self.shape, array)) return out - @use_cpp_method(xla_extension_version >= 138) + @use_cpp_method() def delete(self): if self._arrays is None: return @@ -524,11 +516,11 @@ class ArrayImpl(basearray.Array): db.block_until_ready() return self - @use_cpp_method(xla_extension_version >= 138) + @use_cpp_method() def _single_device_array_to_np_array(self): return np.asarray(self._arrays[0]) - @use_cpp_method(xla_extension_version >= 138) + @use_cpp_method() def _copy_single_device_array_to_host_async(self): self._arrays[0].copy_to_host_async() @@ -541,10 +533,7 @@ class ArrayImpl(basearray.Array): return copy_plan = _create_copy_plan(self._arrays, self.sharding, self.shape) for _, arr in copy_plan: - if xla_extension_version >= 140: - arr._copy_single_device_array_to_host_async() - else: - arr.copy_to_host_async() + arr._copy_single_device_array_to_host_async() @property @functools.partial(profiler.annotate_function, name="np.asarray(jax.Array)") @@ -568,17 +557,11 @@ class ArrayImpl(basearray.Array): copy_plan = _create_copy_plan(self._arrays, self.sharding, self.shape) for _, arr in copy_plan: - if xla_extension_version >= 140: - arr._copy_single_device_array_to_host_async() - else: - arr.copy_to_host_async() + arr._copy_single_device_array_to_host_async() npy_value = np.empty(self.shape, self.dtype) for ind, arr in copy_plan: - if xla_extension_version >= 140: - npy_value[ind] = arr._single_device_array_to_np_array() - else: - npy_value[ind] = np.asarray(arr) + npy_value[ind] = arr._single_device_array_to_np_array() self._npy_value = npy_value # type: ignore self._npy_value.flags.writeable = False # https://docs.python.org/3/library/typing.html#typing.cast diff --git a/jax/_src/config.py b/jax/_src/config.py index 9203ab3cc..fb7fdb5ba 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -25,7 +25,6 @@ from jax._src import lib from jax._src.lib import jax_jit from jax._src.lib import transfer_guard_lib from jax._src.lib import xla_client -from jax._src.lib import xla_extension_version logger = logging.getLogger(__name__) @@ -758,8 +757,6 @@ def _update_jax_array_global(val): 'jax.config.jax_array cannot be disabled after jax 0.4.7 release.' ' Please downgrade to jax and jaxlib 0.4.6 if you want to disable' ' jax.config.jax_array.') - if xla_extension_version < 141: - lib.jax_jit.global_state().jax_array = val def _update_jax_array_thread_local(val): if val is not None and not val: @@ -767,8 +764,6 @@ def _update_jax_array_thread_local(val): 'jax.config.jax_array cannot be disabled after jax 0.4.7 release.' ' Please downgrade to jax and jaxlib 0.4.6 if you want to disable' ' jax.config.jax_array.') - if xla_extension_version < 141: - lib.jax_jit.thread_local_state().jax_array = val jax_array = config.define_bool_state( name='jax_array', diff --git a/jax/_src/dlpack.py b/jax/_src/dlpack.py index 510d39e44..bba303675 100644 --- a/jax/_src/dlpack.py +++ b/jax/_src/dlpack.py @@ -17,7 +17,6 @@ from jax._src import device_array from jax._src import array from jax._src import xla_bridge from jax._src.lib import xla_client -from jax._src.lib import xla_extension_version SUPPORTED_DTYPES = frozenset({ @@ -40,24 +39,12 @@ def to_dlpack(x: device_array.DeviceArrayProtocol, take_ownership: bool = False) undefined behavior if the DLPack consumer writes to a buffer that JAX owns. """ - if xla_extension_version >= 140: - if not isinstance(x, array.ArrayImpl): - raise TypeError("Argument to to_dlpack must be a jax.Array, " - f"got {type(x)}") - assert len(x.devices()) == 1 - return xla_client._xla.buffer_to_dlpack_managed_tensor( - x.addressable_data(0), take_ownership=take_ownership) # type: ignore - else: - if not isinstance(x, (device_array.DeviceArray, array.ArrayImpl)): - raise TypeError("Argument to to_dlpack must be a DeviceArray or Array, " - f"got {type(x)}") - if isinstance(x, array.ArrayImpl): - assert len(x._arrays) == 1 - buf = x._arrays[0] - else: - buf = x.device_buffer - return xla_client._xla.buffer_to_dlpack_managed_tensor( - buf, take_ownership=take_ownership) + if not isinstance(x, array.ArrayImpl): + raise TypeError("Argument to to_dlpack must be a jax.Array, " + f"got {type(x)}") + assert len(x.devices()) == 1 + return xla_client._xla.buffer_to_dlpack_managed_tensor( + x.addressable_data(0), take_ownership=take_ownership) # type: ignore def from_dlpack(dlpack): """Returns a ``DeviceArray`` representation of a DLPack tensor. @@ -80,14 +67,5 @@ def from_dlpack(dlpack): except RuntimeError: gpu_backend = None - if xla_extension_version >= 140: - return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer( - dlpack, cpu_backend, gpu_backend)) - else: - buf = xla_client._xla.dlpack_managed_tensor_to_buffer( - dlpack, cpu_backend, gpu_backend) - if isinstance(buf, array.ArrayImpl): - return jnp.asarray(buf) # asarray ensures dtype canonicalization - else: - return jnp.asarray(array._single_device_array_from_buf( - buf, committed=buf.device() is not None)) + return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer( + dlpack, cpu_backend, gpu_backend)) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 4099562e7..df988a04b 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -63,7 +63,7 @@ from jax._src.lax.utils import ( ) from jax._src.lib import pytree from jax._src import xla_bridge -from jax._src.lib import xla_client, xla_extension_version +from jax._src.lib import xla_client from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo @@ -4209,14 +4209,8 @@ def _rng_bit_generator_lowering( (key_shape == [2] and key_etype == u64_type)), (key_shape, key_etype) dtype = np.dtype(dtype) etype = mlir.dtype_to_ir_type(dtype) - if ( - dtype == np.dtype('uint32') - or dtype == np.dtype('uint64') - or ( - xla_extension_version >= 140 - and (dtype == np.dtype('uint16') or dtype == np.dtype('uint8')) - ) - ): + if dtype in (np.dtype('uint8'), np.dtype('uint16'), np.dtype('uint32'), + np.dtype('uint64')): rbg_etype = etype else: rbg_etype = u32_type diff --git a/jax/_src/random.py b/jax/_src/random.py index e2799dda3..d8238c4e5 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -32,7 +32,6 @@ from jax._src import core from jax._src import dtypes from jax._src import prng from jax._src import xla_bridge -from jax._src.lib import xla_extension_version from jax._src.api import jit, vmap from jax._src.core import NamedShape from jax._src.interpreters import ad @@ -286,7 +285,7 @@ def _uniform(key, shape, dtype, minval, maxval) -> Array: raise TypeError("uniform only accepts 32- or 64-bit dtypes.") rng_bits = nbits - if xla_extension_version >= 140 and nmant < 8: + if nmant < 8: rng_bits = 8 bits = _random_bits(key, rng_bits, shape) uint_dtype = UINT_DTYPES[nbits] diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 9aed9638d..a081227e9 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -319,15 +319,7 @@ def register_pjrt_plugin_factories(plugins_from_env: str) -> None: options = None xla_client.load_pjrt_plugin_dynamically(name, library_path) - if lib.xla_extension_version >= 134: - return xla_client.make_c_api_client(name, options) - else: - if options: - raise ValueError( - 'Setting PJRT plugin options through json file requires' - ' jaxlib.xla_extension_version >= 134.' - ) - return xla_client.make_c_api_client(name) + return xla_client.make_c_api_client(name, options) return factory diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index 64a1a88cc..e69b2a638 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -1964,16 +1964,10 @@ def _initialize_outfeed_receiver( device_repr = ", ".join([str(d) for d in devices_with_outfeed]) logger.debug("Starting outfeed_receiver for %s. max_callback_queue_size_bytes=%s", device_repr, max_callback_queue_size_bytes) - if jaxlib.xla_extension_version > 143: - # TODO(phawkins): remove type:ignore after minimum jaxlib version bump - _callback_handler_data.receiver = outfeed_receiver_module.start( - _callback_input_received, tuple(clients_with_outfeed), - max_callback_queue_size_bytes, - xb.get_compile_options(1, 1).executable_build_options) # type:ignore - else: - _callback_handler_data.receiver = outfeed_receiver_module.start( - _callback_input_received, tuple(clients_with_outfeed), - max_callback_queue_size_bytes) + _callback_handler_data.receiver = outfeed_receiver_module.start( + _callback_input_received, tuple(clients_with_outfeed), + max_callback_queue_size_bytes, + xb.get_compile_options(1, 1).executable_build_options) # type:ignore def exit_handler(): # Prevent logging usage during compilation, gives errors under pytest diff --git a/jax/experimental/jax2tf/jax_export.py b/jax/experimental/jax2tf/jax_export.py index 2fd43a23b..9b9073e40 100644 --- a/jax/experimental/jax2tf/jax_export.py +++ b/jax/experimental/jax2tf/jax_export.py @@ -129,15 +129,11 @@ def serialize_native(fun_jax: Callable, assert len(module_kept_var_idx) == len(args_avals) mlir_module = compute_dim_vars(mlir_module, args_avals) - if xla_client.mlir_api_version >= 46: - xla_call_module_version = 4 - mlir_str = mlir.module_to_bytecode(mlir_module) - target_version = stablehlo.get_earliest_forward_compatible_version() - mlir_module_serialized = xla_client._xla.mlir.serialize_portable_artifact( - mlir_str, target_version) - else: - xla_call_module_version = 3 - mlir_module_serialized = mlir.module_to_bytecode(mlir_module) + xla_call_module_version = 4 + mlir_str = mlir.module_to_bytecode(mlir_module) + target_version = stablehlo.get_earliest_forward_compatible_version() + mlir_module_serialized = xla_client._xla.mlir.serialize_portable_artifact( + mlir_str, target_version) # Figure out the result types and shapes if "global_out_avals" in lowered.compile_args: diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index 452e3bb2f..a543c7690 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -2631,17 +2631,6 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase): if harness.group_name in require_stablehlo_feature_support: raise unittest.SkipTest( "native lowering with shape polymorphism requires additional StableHLO feature support") - # API version 47 supports CHLO ops that decompose into shape dialect ops - if xla_client.mlir_api_version < 47: - require_stablehlo_feature_support_shape_dialect = { - "vmap_acosh", "vmap_asin", "vmap_asinh", "vmap_atan", "vmap_atanh", - "vmap_bessel_i1e", "vmap_cosh", "vmap_digamma", "vmap_erf", - "vmap_erfc", "vmap_lgamma", "vmap_nextafter", - "vmap_nextafter_broadcasting", "vmap_sinh" - } - if harness.group_name in require_stablehlo_feature_support_shape_dialect: - raise unittest.SkipTest( - "native lowering with shape polymorphism requires additional StableHLO feature support") if (jtu.device_under_test() == "tpu" and harness.fullname in [ "jnp.cumsum_reduce_axis=poly", diff --git a/jax/version.py b/jax/version.py index 8b8b051e8..b96927a4c 100644 --- a/jax/version.py +++ b/jax/version.py @@ -16,7 +16,7 @@ # eval()-ed by setup.py, so it should not have any dependencies. __version__ = "0.4.8" -_minimum_jaxlib_version = "0.4.6" +_minimum_jaxlib_version = "0.4.7" def _version_as_tuple(version_str): return tuple(int(i) for i in version_str.split(".") if i.isdigit()) diff --git a/tests/jax_jit_test.py b/tests/jax_jit_test.py index c61ed6735..6ebcffc62 100644 --- a/tests/jax_jit_test.py +++ b/tests/jax_jit_test.py @@ -19,9 +19,7 @@ from absl.testing import parameterized import jax from jax._src import core from jax._src import api_util -from jax._src.interpreters import xla from jax._src.interpreters import pxla -from jax._src.lib import xla_extension_version from jax import dtypes from jax._src import lib as jaxlib from jax import numpy as jnp @@ -33,8 +31,6 @@ config.parse_flags_with_absl() def _cpp_device_put(value, device): aval = api_util.shaped_abstractify(value) - if xla_extension_version < 139: - value = xla.canonicalize_dtype(value) return pxla.batched_device_put( aval, jax.sharding.SingleDeviceSharding(device), [value], [device]) diff --git a/tests/lax_test.py b/tests/lax_test.py index 826335721..80361fcea 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -39,7 +39,6 @@ from jax.interpreters import xla from jax._src.interpreters import mlir from jax.interpreters import batching from jax._src import array -from jax._src.lib import xla_extension_version from jax._src.lib.mlir.dialects import hlo from jax._src import dtypes from jax._src.interpreters import pxla @@ -2855,8 +2854,6 @@ class FooTyRules: buf, = arr._arrays else: buf, = arr - if xla_extension_version < 140: - buf.aval = core.ShapedArray(buf.shape, buf.dtype) return FooArray(aval.shape, buf) return handler