Bump minimum jaxlib version from 0.4.6 to 0.4.7.

Also removes a bunch of dead version guards (0.4.7 has
xla_extension_version 144 and mlir_api_version 47)
This commit is contained in:
Skye Wanderman-Milne 2023-03-28 12:43:32 -07:00
parent 2f105bde2d
commit 00acf459c6
14 changed files with 36 additions and 125 deletions

View File

@ -8,6 +8,9 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.8
* Changes
* The minimum jaxlib version has been bumped from 0.4.6 to 0.4.7.
* Deprecations
* CUDA 11.4 support has been dropped. JAX GPU wheels only support
CUDA 11.8 and CUDA 12. Older CUDA versions may work if jaxlib is built

View File

@ -64,7 +64,6 @@ from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
from jax._src.lib import pmap_lib
from jax._src.lib import xla_extension_version
from jax._src.sharding_impls import PmapSharding
from jax._src.traceback_util import api_boundary
from jax._src.tree_util import broadcast_prefix, generate_key_paths
@ -2605,8 +2604,6 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): #
raise ValueError("the shards passed to device_put_sharded must have "
f"consistent shape and dtype, but got {a1} and {a2}.")
stacked_aval = avals[0].update(shape=(len(devices),) + avals[0].shape)
if xla_extension_version < 139:
xs = [xla.canonicalize_dtype(arg) for arg in xs]
sharding_spec = pxla._create_pmap_sharding_spec(stacked_aval)
return pxla.batched_device_put(
stacked_aval, PmapSharding(np.array(devices), sharding_spec),
@ -2656,8 +2653,6 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811
assert (isinstance(aval, ShapedArray) and
len(xla.aval_to_xla_shapes(aval)) == 1)
sharding_spec = pxla._create_pmap_sharding_spec(aval)
if xla_extension_version < 139:
x = xla.canonicalize_dtype(x)
buf = jax.device_put(x, devices[0])
return pxla.batched_device_put(
aval, PmapSharding(np.array(devices), sharding_spec),

View File

@ -32,7 +32,6 @@ from jax._src import profiler
from jax._src.config import config
from jax._src.util import use_cpp_class, use_cpp_method
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src import api
from jax._src.typing import ArrayLike
from jax.interpreters import mlir
@ -341,14 +340,7 @@ class ArrayImpl(basearray.Array):
return _single_device_array_from_buf(arr, committed=False)
return lax_numpy._rewriting_take(self, idx)
else:
if xla_extension_version >= 144:
return lax_numpy._rewriting_take(self, idx)
else:
if (dispatch.is_single_device_sharding(self.sharding) or
self.is_fully_replicated or _is_reduced_on_dim(idx)):
return lax_numpy._rewriting_take(self, idx)
else:
return api.device_put(self._value[idx])
def __iter__(self):
if self.ndim == 0:
@ -404,7 +396,7 @@ class ArrayImpl(basearray.Array):
'named_shape': self.aval.named_shape}
return (_reconstruct_array, (fun, args, arr_state, aval_state))
@use_cpp_method(xla_extension_version >= 138)
@use_cpp_method()
def unsafe_buffer_pointer(self):
if len(self._arrays) != 1:
raise ValueError("unsafe_buffer_pointer() is supported only for unsharded"
@ -412,14 +404,14 @@ class ArrayImpl(basearray.Array):
return self._arrays[0].unsafe_buffer_pointer()
@property
@use_cpp_method(xla_extension_version >= 138)
@use_cpp_method()
def __cuda_array_interface__(self):
if len(self._arrays) != 1:
raise ValueError("__cuda_array_interface__() is supported only for "
"unsharded arrays.")
return self._arrays[0].__cuda_array_interface__ # pytype: disable=attribute-error # bind-properties
@use_cpp_method(xla_extension_version >= 138)
@use_cpp_method()
def on_device_size_in_bytes(self):
"""Returns the total global on-device size of the array in bytes."""
arr = self._arrays[0]
@ -495,7 +487,7 @@ class ArrayImpl(basearray.Array):
out.append(Shard(global_d, self.sharding, self.shape, array))
return out
@use_cpp_method(xla_extension_version >= 138)
@use_cpp_method()
def delete(self):
if self._arrays is None:
return
@ -524,11 +516,11 @@ class ArrayImpl(basearray.Array):
db.block_until_ready()
return self
@use_cpp_method(xla_extension_version >= 138)
@use_cpp_method()
def _single_device_array_to_np_array(self):
return np.asarray(self._arrays[0])
@use_cpp_method(xla_extension_version >= 138)
@use_cpp_method()
def _copy_single_device_array_to_host_async(self):
self._arrays[0].copy_to_host_async()
@ -541,10 +533,7 @@ class ArrayImpl(basearray.Array):
return
copy_plan = _create_copy_plan(self._arrays, self.sharding, self.shape)
for _, arr in copy_plan:
if xla_extension_version >= 140:
arr._copy_single_device_array_to_host_async()
else:
arr.copy_to_host_async()
@property
@functools.partial(profiler.annotate_function, name="np.asarray(jax.Array)")
@ -568,17 +557,11 @@ class ArrayImpl(basearray.Array):
copy_plan = _create_copy_plan(self._arrays, self.sharding, self.shape)
for _, arr in copy_plan:
if xla_extension_version >= 140:
arr._copy_single_device_array_to_host_async()
else:
arr.copy_to_host_async()
npy_value = np.empty(self.shape, self.dtype)
for ind, arr in copy_plan:
if xla_extension_version >= 140:
npy_value[ind] = arr._single_device_array_to_np_array()
else:
npy_value[ind] = np.asarray(arr)
self._npy_value = npy_value # type: ignore
self._npy_value.flags.writeable = False
# https://docs.python.org/3/library/typing.html#typing.cast

View File

@ -25,7 +25,6 @@ from jax._src import lib
from jax._src.lib import jax_jit
from jax._src.lib import transfer_guard_lib
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
logger = logging.getLogger(__name__)
@ -758,8 +757,6 @@ def _update_jax_array_global(val):
'jax.config.jax_array cannot be disabled after jax 0.4.7 release.'
' Please downgrade to jax and jaxlib 0.4.6 if you want to disable'
' jax.config.jax_array.')
if xla_extension_version < 141:
lib.jax_jit.global_state().jax_array = val
def _update_jax_array_thread_local(val):
if val is not None and not val:
@ -767,8 +764,6 @@ def _update_jax_array_thread_local(val):
'jax.config.jax_array cannot be disabled after jax 0.4.7 release.'
' Please downgrade to jax and jaxlib 0.4.6 if you want to disable'
' jax.config.jax_array.')
if xla_extension_version < 141:
lib.jax_jit.thread_local_state().jax_array = val
jax_array = config.define_bool_state(
name='jax_array',

View File

@ -17,7 +17,6 @@ from jax._src import device_array
from jax._src import array
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
SUPPORTED_DTYPES = frozenset({
@ -40,24 +39,12 @@ def to_dlpack(x: device_array.DeviceArrayProtocol, take_ownership: bool = False)
undefined behavior if the DLPack consumer writes to a buffer that JAX
owns.
"""
if xla_extension_version >= 140:
if not isinstance(x, array.ArrayImpl):
raise TypeError("Argument to to_dlpack must be a jax.Array, "
f"got {type(x)}")
assert len(x.devices()) == 1
return xla_client._xla.buffer_to_dlpack_managed_tensor(
x.addressable_data(0), take_ownership=take_ownership) # type: ignore
else:
if not isinstance(x, (device_array.DeviceArray, array.ArrayImpl)):
raise TypeError("Argument to to_dlpack must be a DeviceArray or Array, "
f"got {type(x)}")
if isinstance(x, array.ArrayImpl):
assert len(x._arrays) == 1
buf = x._arrays[0]
else:
buf = x.device_buffer
return xla_client._xla.buffer_to_dlpack_managed_tensor(
buf, take_ownership=take_ownership)
def from_dlpack(dlpack):
"""Returns a ``DeviceArray`` representation of a DLPack tensor.
@ -80,14 +67,5 @@ def from_dlpack(dlpack):
except RuntimeError:
gpu_backend = None
if xla_extension_version >= 140:
return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
dlpack, cpu_backend, gpu_backend))
else:
buf = xla_client._xla.dlpack_managed_tensor_to_buffer(
dlpack, cpu_backend, gpu_backend)
if isinstance(buf, array.ArrayImpl):
return jnp.asarray(buf) # asarray ensures dtype canonicalization
else:
return jnp.asarray(array._single_device_array_from_buf(
buf, committed=buf.device() is not None))

View File

@ -63,7 +63,7 @@ from jax._src.lax.utils import (
)
from jax._src.lib import pytree
from jax._src import xla_bridge
from jax._src.lib import xla_client, xla_extension_version
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import hlo
@ -4209,14 +4209,8 @@ def _rng_bit_generator_lowering(
(key_shape == [2] and key_etype == u64_type)), (key_shape, key_etype)
dtype = np.dtype(dtype)
etype = mlir.dtype_to_ir_type(dtype)
if (
dtype == np.dtype('uint32')
or dtype == np.dtype('uint64')
or (
xla_extension_version >= 140
and (dtype == np.dtype('uint16') or dtype == np.dtype('uint8'))
)
):
if dtype in (np.dtype('uint8'), np.dtype('uint16'), np.dtype('uint32'),
np.dtype('uint64')):
rbg_etype = etype
else:
rbg_etype = u32_type

View File

@ -32,7 +32,6 @@ from jax._src import core
from jax._src import dtypes
from jax._src import prng
from jax._src import xla_bridge
from jax._src.lib import xla_extension_version
from jax._src.api import jit, vmap
from jax._src.core import NamedShape
from jax._src.interpreters import ad
@ -286,7 +285,7 @@ def _uniform(key, shape, dtype, minval, maxval) -> Array:
raise TypeError("uniform only accepts 32- or 64-bit dtypes.")
rng_bits = nbits
if xla_extension_version >= 140 and nmant < 8:
if nmant < 8:
rng_bits = 8
bits = _random_bits(key, rng_bits, shape)
uint_dtype = UINT_DTYPES[nbits]

View File

@ -319,15 +319,7 @@ def register_pjrt_plugin_factories(plugins_from_env: str) -> None:
options = None
xla_client.load_pjrt_plugin_dynamically(name, library_path)
if lib.xla_extension_version >= 134:
return xla_client.make_c_api_client(name, options)
else:
if options:
raise ValueError(
'Setting PJRT plugin options through json file requires'
' jaxlib.xla_extension_version >= 134.'
)
return xla_client.make_c_api_client(name)
return factory

View File

@ -1964,16 +1964,10 @@ def _initialize_outfeed_receiver(
device_repr = ", ".join([str(d) for d in devices_with_outfeed])
logger.debug("Starting outfeed_receiver for %s. max_callback_queue_size_bytes=%s",
device_repr, max_callback_queue_size_bytes)
if jaxlib.xla_extension_version > 143:
# TODO(phawkins): remove type:ignore after minimum jaxlib version bump
_callback_handler_data.receiver = outfeed_receiver_module.start(
_callback_input_received, tuple(clients_with_outfeed),
max_callback_queue_size_bytes,
xb.get_compile_options(1, 1).executable_build_options) # type:ignore
else:
_callback_handler_data.receiver = outfeed_receiver_module.start(
_callback_input_received, tuple(clients_with_outfeed),
max_callback_queue_size_bytes)
def exit_handler():
# Prevent logging usage during compilation, gives errors under pytest

View File

@ -129,15 +129,11 @@ def serialize_native(fun_jax: Callable,
assert len(module_kept_var_idx) == len(args_avals)
mlir_module = compute_dim_vars(mlir_module, args_avals)
if xla_client.mlir_api_version >= 46:
xla_call_module_version = 4
mlir_str = mlir.module_to_bytecode(mlir_module)
target_version = stablehlo.get_earliest_forward_compatible_version()
mlir_module_serialized = xla_client._xla.mlir.serialize_portable_artifact(
mlir_str, target_version)
else:
xla_call_module_version = 3
mlir_module_serialized = mlir.module_to_bytecode(mlir_module)
# Figure out the result types and shapes
if "global_out_avals" in lowered.compile_args:

View File

@ -2631,17 +2631,6 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
if harness.group_name in require_stablehlo_feature_support:
raise unittest.SkipTest(
"native lowering with shape polymorphism requires additional StableHLO feature support")
# API version 47 supports CHLO ops that decompose into shape dialect ops
if xla_client.mlir_api_version < 47:
require_stablehlo_feature_support_shape_dialect = {
"vmap_acosh", "vmap_asin", "vmap_asinh", "vmap_atan", "vmap_atanh",
"vmap_bessel_i1e", "vmap_cosh", "vmap_digamma", "vmap_erf",
"vmap_erfc", "vmap_lgamma", "vmap_nextafter",
"vmap_nextafter_broadcasting", "vmap_sinh"
}
if harness.group_name in require_stablehlo_feature_support_shape_dialect:
raise unittest.SkipTest(
"native lowering with shape polymorphism requires additional StableHLO feature support")
if (jtu.device_under_test() == "tpu" and
harness.fullname in [
"jnp.cumsum_reduce_axis=poly",

View File

@ -16,7 +16,7 @@
# eval()-ed by setup.py, so it should not have any dependencies.
__version__ = "0.4.8"
_minimum_jaxlib_version = "0.4.6"
_minimum_jaxlib_version = "0.4.7"
def _version_as_tuple(version_str):
return tuple(int(i) for i in version_str.split(".") if i.isdigit())

View File

@ -19,9 +19,7 @@ from absl.testing import parameterized
import jax
from jax._src import core
from jax._src import api_util
from jax._src.interpreters import xla
from jax._src.interpreters import pxla
from jax._src.lib import xla_extension_version
from jax import dtypes
from jax._src import lib as jaxlib
from jax import numpy as jnp
@ -33,8 +31,6 @@ config.parse_flags_with_absl()
def _cpp_device_put(value, device):
aval = api_util.shaped_abstractify(value)
if xla_extension_version < 139:
value = xla.canonicalize_dtype(value)
return pxla.batched_device_put(
aval, jax.sharding.SingleDeviceSharding(device), [value], [device])

View File

@ -39,7 +39,6 @@ from jax.interpreters import xla
from jax._src.interpreters import mlir
from jax.interpreters import batching
from jax._src import array
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir.dialects import hlo
from jax._src import dtypes
from jax._src.interpreters import pxla
@ -2855,8 +2854,6 @@ class FooTyRules:
buf, = arr._arrays
else:
buf, = arr
if xla_extension_version < 140:
buf.aval = core.ShapedArray(buf.shape, buf.dtype)
return FooArray(aval.shape, buf)
return handler