Bump minimum jaxlib version from 0.4.6 to 0.4.7.

Also removes a bunch of dead version guards (0.4.7 has
xla_extension_version 144 and mlir_api_version 47)
This commit is contained in:
Skye Wanderman-Milne 2023-03-28 12:43:32 -07:00
parent 2f105bde2d
commit 00acf459c6
14 changed files with 36 additions and 125 deletions

View File

@ -8,6 +8,9 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.8 ## jax 0.4.8
* Changes
* The minimum jaxlib version has been bumped from 0.4.6 to 0.4.7.
* Deprecations * Deprecations
* CUDA 11.4 support has been dropped. JAX GPU wheels only support * CUDA 11.4 support has been dropped. JAX GPU wheels only support
CUDA 11.8 and CUDA 12. Older CUDA versions may work if jaxlib is built CUDA 11.8 and CUDA 12. Older CUDA versions may work if jaxlib is built

View File

@ -64,7 +64,6 @@ from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc from jax._src.lib import xla_client as xc
from jax._src.lib import pmap_lib from jax._src.lib import pmap_lib
from jax._src.lib import xla_extension_version
from jax._src.sharding_impls import PmapSharding from jax._src.sharding_impls import PmapSharding
from jax._src.traceback_util import api_boundary from jax._src.traceback_util import api_boundary
from jax._src.tree_util import broadcast_prefix, generate_key_paths from jax._src.tree_util import broadcast_prefix, generate_key_paths
@ -2605,8 +2604,6 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): #
raise ValueError("the shards passed to device_put_sharded must have " raise ValueError("the shards passed to device_put_sharded must have "
f"consistent shape and dtype, but got {a1} and {a2}.") f"consistent shape and dtype, but got {a1} and {a2}.")
stacked_aval = avals[0].update(shape=(len(devices),) + avals[0].shape) stacked_aval = avals[0].update(shape=(len(devices),) + avals[0].shape)
if xla_extension_version < 139:
xs = [xla.canonicalize_dtype(arg) for arg in xs]
sharding_spec = pxla._create_pmap_sharding_spec(stacked_aval) sharding_spec = pxla._create_pmap_sharding_spec(stacked_aval)
return pxla.batched_device_put( return pxla.batched_device_put(
stacked_aval, PmapSharding(np.array(devices), sharding_spec), stacked_aval, PmapSharding(np.array(devices), sharding_spec),
@ -2656,8 +2653,6 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811
assert (isinstance(aval, ShapedArray) and assert (isinstance(aval, ShapedArray) and
len(xla.aval_to_xla_shapes(aval)) == 1) len(xla.aval_to_xla_shapes(aval)) == 1)
sharding_spec = pxla._create_pmap_sharding_spec(aval) sharding_spec = pxla._create_pmap_sharding_spec(aval)
if xla_extension_version < 139:
x = xla.canonicalize_dtype(x)
buf = jax.device_put(x, devices[0]) buf = jax.device_put(x, devices[0])
return pxla.batched_device_put( return pxla.batched_device_put(
aval, PmapSharding(np.array(devices), sharding_spec), aval, PmapSharding(np.array(devices), sharding_spec),

View File

@ -32,7 +32,6 @@ from jax._src import profiler
from jax._src.config import config from jax._src.config import config
from jax._src.util import use_cpp_class, use_cpp_method from jax._src.util import use_cpp_class, use_cpp_method
from jax._src.lib import xla_client as xc from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src import api from jax._src import api
from jax._src.typing import ArrayLike from jax._src.typing import ArrayLike
from jax.interpreters import mlir from jax.interpreters import mlir
@ -341,14 +340,7 @@ class ArrayImpl(basearray.Array):
return _single_device_array_from_buf(arr, committed=False) return _single_device_array_from_buf(arr, committed=False)
return lax_numpy._rewriting_take(self, idx) return lax_numpy._rewriting_take(self, idx)
else: else:
if xla_extension_version >= 144: return lax_numpy._rewriting_take(self, idx)
return lax_numpy._rewriting_take(self, idx)
else:
if (dispatch.is_single_device_sharding(self.sharding) or
self.is_fully_replicated or _is_reduced_on_dim(idx)):
return lax_numpy._rewriting_take(self, idx)
else:
return api.device_put(self._value[idx])
def __iter__(self): def __iter__(self):
if self.ndim == 0: if self.ndim == 0:
@ -404,7 +396,7 @@ class ArrayImpl(basearray.Array):
'named_shape': self.aval.named_shape} 'named_shape': self.aval.named_shape}
return (_reconstruct_array, (fun, args, arr_state, aval_state)) return (_reconstruct_array, (fun, args, arr_state, aval_state))
@use_cpp_method(xla_extension_version >= 138) @use_cpp_method()
def unsafe_buffer_pointer(self): def unsafe_buffer_pointer(self):
if len(self._arrays) != 1: if len(self._arrays) != 1:
raise ValueError("unsafe_buffer_pointer() is supported only for unsharded" raise ValueError("unsafe_buffer_pointer() is supported only for unsharded"
@ -412,14 +404,14 @@ class ArrayImpl(basearray.Array):
return self._arrays[0].unsafe_buffer_pointer() return self._arrays[0].unsafe_buffer_pointer()
@property @property
@use_cpp_method(xla_extension_version >= 138) @use_cpp_method()
def __cuda_array_interface__(self): def __cuda_array_interface__(self):
if len(self._arrays) != 1: if len(self._arrays) != 1:
raise ValueError("__cuda_array_interface__() is supported only for " raise ValueError("__cuda_array_interface__() is supported only for "
"unsharded arrays.") "unsharded arrays.")
return self._arrays[0].__cuda_array_interface__ # pytype: disable=attribute-error # bind-properties return self._arrays[0].__cuda_array_interface__ # pytype: disable=attribute-error # bind-properties
@use_cpp_method(xla_extension_version >= 138) @use_cpp_method()
def on_device_size_in_bytes(self): def on_device_size_in_bytes(self):
"""Returns the total global on-device size of the array in bytes.""" """Returns the total global on-device size of the array in bytes."""
arr = self._arrays[0] arr = self._arrays[0]
@ -495,7 +487,7 @@ class ArrayImpl(basearray.Array):
out.append(Shard(global_d, self.sharding, self.shape, array)) out.append(Shard(global_d, self.sharding, self.shape, array))
return out return out
@use_cpp_method(xla_extension_version >= 138) @use_cpp_method()
def delete(self): def delete(self):
if self._arrays is None: if self._arrays is None:
return return
@ -524,11 +516,11 @@ class ArrayImpl(basearray.Array):
db.block_until_ready() db.block_until_ready()
return self return self
@use_cpp_method(xla_extension_version >= 138) @use_cpp_method()
def _single_device_array_to_np_array(self): def _single_device_array_to_np_array(self):
return np.asarray(self._arrays[0]) return np.asarray(self._arrays[0])
@use_cpp_method(xla_extension_version >= 138) @use_cpp_method()
def _copy_single_device_array_to_host_async(self): def _copy_single_device_array_to_host_async(self):
self._arrays[0].copy_to_host_async() self._arrays[0].copy_to_host_async()
@ -541,10 +533,7 @@ class ArrayImpl(basearray.Array):
return return
copy_plan = _create_copy_plan(self._arrays, self.sharding, self.shape) copy_plan = _create_copy_plan(self._arrays, self.sharding, self.shape)
for _, arr in copy_plan: for _, arr in copy_plan:
if xla_extension_version >= 140: arr._copy_single_device_array_to_host_async()
arr._copy_single_device_array_to_host_async()
else:
arr.copy_to_host_async()
@property @property
@functools.partial(profiler.annotate_function, name="np.asarray(jax.Array)") @functools.partial(profiler.annotate_function, name="np.asarray(jax.Array)")
@ -568,17 +557,11 @@ class ArrayImpl(basearray.Array):
copy_plan = _create_copy_plan(self._arrays, self.sharding, self.shape) copy_plan = _create_copy_plan(self._arrays, self.sharding, self.shape)
for _, arr in copy_plan: for _, arr in copy_plan:
if xla_extension_version >= 140: arr._copy_single_device_array_to_host_async()
arr._copy_single_device_array_to_host_async()
else:
arr.copy_to_host_async()
npy_value = np.empty(self.shape, self.dtype) npy_value = np.empty(self.shape, self.dtype)
for ind, arr in copy_plan: for ind, arr in copy_plan:
if xla_extension_version >= 140: npy_value[ind] = arr._single_device_array_to_np_array()
npy_value[ind] = arr._single_device_array_to_np_array()
else:
npy_value[ind] = np.asarray(arr)
self._npy_value = npy_value # type: ignore self._npy_value = npy_value # type: ignore
self._npy_value.flags.writeable = False self._npy_value.flags.writeable = False
# https://docs.python.org/3/library/typing.html#typing.cast # https://docs.python.org/3/library/typing.html#typing.cast

View File

@ -25,7 +25,6 @@ from jax._src import lib
from jax._src.lib import jax_jit from jax._src.lib import jax_jit
from jax._src.lib import transfer_guard_lib from jax._src.lib import transfer_guard_lib
from jax._src.lib import xla_client from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -758,8 +757,6 @@ def _update_jax_array_global(val):
'jax.config.jax_array cannot be disabled after jax 0.4.7 release.' 'jax.config.jax_array cannot be disabled after jax 0.4.7 release.'
' Please downgrade to jax and jaxlib 0.4.6 if you want to disable' ' Please downgrade to jax and jaxlib 0.4.6 if you want to disable'
' jax.config.jax_array.') ' jax.config.jax_array.')
if xla_extension_version < 141:
lib.jax_jit.global_state().jax_array = val
def _update_jax_array_thread_local(val): def _update_jax_array_thread_local(val):
if val is not None and not val: if val is not None and not val:
@ -767,8 +764,6 @@ def _update_jax_array_thread_local(val):
'jax.config.jax_array cannot be disabled after jax 0.4.7 release.' 'jax.config.jax_array cannot be disabled after jax 0.4.7 release.'
' Please downgrade to jax and jaxlib 0.4.6 if you want to disable' ' Please downgrade to jax and jaxlib 0.4.6 if you want to disable'
' jax.config.jax_array.') ' jax.config.jax_array.')
if xla_extension_version < 141:
lib.jax_jit.thread_local_state().jax_array = val
jax_array = config.define_bool_state( jax_array = config.define_bool_state(
name='jax_array', name='jax_array',

View File

@ -17,7 +17,6 @@ from jax._src import device_array
from jax._src import array from jax._src import array
from jax._src import xla_bridge from jax._src import xla_bridge
from jax._src.lib import xla_client from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
SUPPORTED_DTYPES = frozenset({ SUPPORTED_DTYPES = frozenset({
@ -40,24 +39,12 @@ def to_dlpack(x: device_array.DeviceArrayProtocol, take_ownership: bool = False)
undefined behavior if the DLPack consumer writes to a buffer that JAX undefined behavior if the DLPack consumer writes to a buffer that JAX
owns. owns.
""" """
if xla_extension_version >= 140: if not isinstance(x, array.ArrayImpl):
if not isinstance(x, array.ArrayImpl): raise TypeError("Argument to to_dlpack must be a jax.Array, "
raise TypeError("Argument to to_dlpack must be a jax.Array, " f"got {type(x)}")
f"got {type(x)}") assert len(x.devices()) == 1
assert len(x.devices()) == 1 return xla_client._xla.buffer_to_dlpack_managed_tensor(
return xla_client._xla.buffer_to_dlpack_managed_tensor( x.addressable_data(0), take_ownership=take_ownership) # type: ignore
x.addressable_data(0), take_ownership=take_ownership) # type: ignore
else:
if not isinstance(x, (device_array.DeviceArray, array.ArrayImpl)):
raise TypeError("Argument to to_dlpack must be a DeviceArray or Array, "
f"got {type(x)}")
if isinstance(x, array.ArrayImpl):
assert len(x._arrays) == 1
buf = x._arrays[0]
else:
buf = x.device_buffer
return xla_client._xla.buffer_to_dlpack_managed_tensor(
buf, take_ownership=take_ownership)
def from_dlpack(dlpack): def from_dlpack(dlpack):
"""Returns a ``DeviceArray`` representation of a DLPack tensor. """Returns a ``DeviceArray`` representation of a DLPack tensor.
@ -80,14 +67,5 @@ def from_dlpack(dlpack):
except RuntimeError: except RuntimeError:
gpu_backend = None gpu_backend = None
if xla_extension_version >= 140: return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer( dlpack, cpu_backend, gpu_backend))
dlpack, cpu_backend, gpu_backend))
else:
buf = xla_client._xla.dlpack_managed_tensor_to_buffer(
dlpack, cpu_backend, gpu_backend)
if isinstance(buf, array.ArrayImpl):
return jnp.asarray(buf) # asarray ensures dtype canonicalization
else:
return jnp.asarray(array._single_device_array_from_buf(
buf, committed=buf.device() is not None))

View File

@ -63,7 +63,7 @@ from jax._src.lax.utils import (
) )
from jax._src.lib import pytree from jax._src.lib import pytree
from jax._src import xla_bridge from jax._src import xla_bridge
from jax._src.lib import xla_client, xla_extension_version from jax._src.lib import xla_client
from jax._src.lib.mlir import ir from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import hlo from jax._src.lib.mlir.dialects import hlo
@ -4209,14 +4209,8 @@ def _rng_bit_generator_lowering(
(key_shape == [2] and key_etype == u64_type)), (key_shape, key_etype) (key_shape == [2] and key_etype == u64_type)), (key_shape, key_etype)
dtype = np.dtype(dtype) dtype = np.dtype(dtype)
etype = mlir.dtype_to_ir_type(dtype) etype = mlir.dtype_to_ir_type(dtype)
if ( if dtype in (np.dtype('uint8'), np.dtype('uint16'), np.dtype('uint32'),
dtype == np.dtype('uint32') np.dtype('uint64')):
or dtype == np.dtype('uint64')
or (
xla_extension_version >= 140
and (dtype == np.dtype('uint16') or dtype == np.dtype('uint8'))
)
):
rbg_etype = etype rbg_etype = etype
else: else:
rbg_etype = u32_type rbg_etype = u32_type

View File

@ -32,7 +32,6 @@ from jax._src import core
from jax._src import dtypes from jax._src import dtypes
from jax._src import prng from jax._src import prng
from jax._src import xla_bridge from jax._src import xla_bridge
from jax._src.lib import xla_extension_version
from jax._src.api import jit, vmap from jax._src.api import jit, vmap
from jax._src.core import NamedShape from jax._src.core import NamedShape
from jax._src.interpreters import ad from jax._src.interpreters import ad
@ -286,7 +285,7 @@ def _uniform(key, shape, dtype, minval, maxval) -> Array:
raise TypeError("uniform only accepts 32- or 64-bit dtypes.") raise TypeError("uniform only accepts 32- or 64-bit dtypes.")
rng_bits = nbits rng_bits = nbits
if xla_extension_version >= 140 and nmant < 8: if nmant < 8:
rng_bits = 8 rng_bits = 8
bits = _random_bits(key, rng_bits, shape) bits = _random_bits(key, rng_bits, shape)
uint_dtype = UINT_DTYPES[nbits] uint_dtype = UINT_DTYPES[nbits]

View File

@ -319,15 +319,7 @@ def register_pjrt_plugin_factories(plugins_from_env: str) -> None:
options = None options = None
xla_client.load_pjrt_plugin_dynamically(name, library_path) xla_client.load_pjrt_plugin_dynamically(name, library_path)
if lib.xla_extension_version >= 134: return xla_client.make_c_api_client(name, options)
return xla_client.make_c_api_client(name, options)
else:
if options:
raise ValueError(
'Setting PJRT plugin options through json file requires'
' jaxlib.xla_extension_version >= 134.'
)
return xla_client.make_c_api_client(name)
return factory return factory

View File

@ -1964,16 +1964,10 @@ def _initialize_outfeed_receiver(
device_repr = ", ".join([str(d) for d in devices_with_outfeed]) device_repr = ", ".join([str(d) for d in devices_with_outfeed])
logger.debug("Starting outfeed_receiver for %s. max_callback_queue_size_bytes=%s", logger.debug("Starting outfeed_receiver for %s. max_callback_queue_size_bytes=%s",
device_repr, max_callback_queue_size_bytes) device_repr, max_callback_queue_size_bytes)
if jaxlib.xla_extension_version > 143: _callback_handler_data.receiver = outfeed_receiver_module.start(
# TODO(phawkins): remove type:ignore after minimum jaxlib version bump _callback_input_received, tuple(clients_with_outfeed),
_callback_handler_data.receiver = outfeed_receiver_module.start( max_callback_queue_size_bytes,
_callback_input_received, tuple(clients_with_outfeed), xb.get_compile_options(1, 1).executable_build_options) # type:ignore
max_callback_queue_size_bytes,
xb.get_compile_options(1, 1).executable_build_options) # type:ignore
else:
_callback_handler_data.receiver = outfeed_receiver_module.start(
_callback_input_received, tuple(clients_with_outfeed),
max_callback_queue_size_bytes)
def exit_handler(): def exit_handler():
# Prevent logging usage during compilation, gives errors under pytest # Prevent logging usage during compilation, gives errors under pytest

View File

@ -129,15 +129,11 @@ def serialize_native(fun_jax: Callable,
assert len(module_kept_var_idx) == len(args_avals) assert len(module_kept_var_idx) == len(args_avals)
mlir_module = compute_dim_vars(mlir_module, args_avals) mlir_module = compute_dim_vars(mlir_module, args_avals)
if xla_client.mlir_api_version >= 46: xla_call_module_version = 4
xla_call_module_version = 4 mlir_str = mlir.module_to_bytecode(mlir_module)
mlir_str = mlir.module_to_bytecode(mlir_module) target_version = stablehlo.get_earliest_forward_compatible_version()
target_version = stablehlo.get_earliest_forward_compatible_version() mlir_module_serialized = xla_client._xla.mlir.serialize_portable_artifact(
mlir_module_serialized = xla_client._xla.mlir.serialize_portable_artifact( mlir_str, target_version)
mlir_str, target_version)
else:
xla_call_module_version = 3
mlir_module_serialized = mlir.module_to_bytecode(mlir_module)
# Figure out the result types and shapes # Figure out the result types and shapes
if "global_out_avals" in lowered.compile_args: if "global_out_avals" in lowered.compile_args:

View File

@ -2631,17 +2631,6 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
if harness.group_name in require_stablehlo_feature_support: if harness.group_name in require_stablehlo_feature_support:
raise unittest.SkipTest( raise unittest.SkipTest(
"native lowering with shape polymorphism requires additional StableHLO feature support") "native lowering with shape polymorphism requires additional StableHLO feature support")
# API version 47 supports CHLO ops that decompose into shape dialect ops
if xla_client.mlir_api_version < 47:
require_stablehlo_feature_support_shape_dialect = {
"vmap_acosh", "vmap_asin", "vmap_asinh", "vmap_atan", "vmap_atanh",
"vmap_bessel_i1e", "vmap_cosh", "vmap_digamma", "vmap_erf",
"vmap_erfc", "vmap_lgamma", "vmap_nextafter",
"vmap_nextafter_broadcasting", "vmap_sinh"
}
if harness.group_name in require_stablehlo_feature_support_shape_dialect:
raise unittest.SkipTest(
"native lowering with shape polymorphism requires additional StableHLO feature support")
if (jtu.device_under_test() == "tpu" and if (jtu.device_under_test() == "tpu" and
harness.fullname in [ harness.fullname in [
"jnp.cumsum_reduce_axis=poly", "jnp.cumsum_reduce_axis=poly",

View File

@ -16,7 +16,7 @@
# eval()-ed by setup.py, so it should not have any dependencies. # eval()-ed by setup.py, so it should not have any dependencies.
__version__ = "0.4.8" __version__ = "0.4.8"
_minimum_jaxlib_version = "0.4.6" _minimum_jaxlib_version = "0.4.7"
def _version_as_tuple(version_str): def _version_as_tuple(version_str):
return tuple(int(i) for i in version_str.split(".") if i.isdigit()) return tuple(int(i) for i in version_str.split(".") if i.isdigit())

View File

@ -19,9 +19,7 @@ from absl.testing import parameterized
import jax import jax
from jax._src import core from jax._src import core
from jax._src import api_util from jax._src import api_util
from jax._src.interpreters import xla
from jax._src.interpreters import pxla from jax._src.interpreters import pxla
from jax._src.lib import xla_extension_version
from jax import dtypes from jax import dtypes
from jax._src import lib as jaxlib from jax._src import lib as jaxlib
from jax import numpy as jnp from jax import numpy as jnp
@ -33,8 +31,6 @@ config.parse_flags_with_absl()
def _cpp_device_put(value, device): def _cpp_device_put(value, device):
aval = api_util.shaped_abstractify(value) aval = api_util.shaped_abstractify(value)
if xla_extension_version < 139:
value = xla.canonicalize_dtype(value)
return pxla.batched_device_put( return pxla.batched_device_put(
aval, jax.sharding.SingleDeviceSharding(device), [value], [device]) aval, jax.sharding.SingleDeviceSharding(device), [value], [device])

View File

@ -39,7 +39,6 @@ from jax.interpreters import xla
from jax._src.interpreters import mlir from jax._src.interpreters import mlir
from jax.interpreters import batching from jax.interpreters import batching
from jax._src import array from jax._src import array
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir.dialects import hlo from jax._src.lib.mlir.dialects import hlo
from jax._src import dtypes from jax._src import dtypes
from jax._src.interpreters import pxla from jax._src.interpreters import pxla
@ -2855,8 +2854,6 @@ class FooTyRules:
buf, = arr._arrays buf, = arr._arrays
else: else:
buf, = arr buf, = arr
if xla_extension_version < 140:
buf.aval = core.ShapedArray(buf.shape, buf.dtype)
return FooArray(aval.shape, buf) return FooArray(aval.shape, buf)
return handler return handler