Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00.
Bump minimum jaxlib version from 0.4.6 to 0.4.7.

Also removes a number of now-dead version guards: jaxlib 0.4.7 ships xla_extension_version 144 and mlir_api_version 47, so guards against older values can never fire.
Commit 00acf459c6 (parent 2f105bde2d)
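Every guard deleted below has the same shape: the code imports a feature-version constant from jax._src.lib (xla_extension_version, or xla_client.mlir_api_version) and branches on it so that jax keeps working against an older installed jaxlib. A minimal sketch of the pattern, with new_path/old_path as hypothetical stand-ins for the real call sites:

    from jax._src.lib import xla_extension_version

    def new_path(x):  # hypothetical: the call a newer jaxlib supports
      return x

    def old_path(x):  # hypothetical: the fallback for an older jaxlib
      return x

    def do_work(x):
      if xla_extension_version >= 140:  # guard on the jaxlib feature version
        return new_path(x)
      return old_path(x)

With the minimum jaxlib at 0.4.7, xla_extension_version is always 144 (and mlir_api_version 47), so every such test against a smaller constant is unconditionally true; the hunks below delete the guards and their dead fallback branches.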
@@ -8,6 +8,9 @@ Remember to align the itemized text with the first line of an item within a list
 ## jax 0.4.8

+* Changes
+  * The minimum jaxlib version has been bumped from 0.4.6 to 0.4.7.
+
 * Deprecations
   * CUDA 11.4 support has been dropped. JAX GPU wheels only support
     CUDA 11.8 and CUDA 12. Older CUDA versions may work if jaxlib is built
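The new floor is enforced when jax is imported: the installed jaxlib version is compared against _minimum_jaxlib_version and import fails if it is too old. A simplified sketch of that kind of check (illustrative names; the real logic lives in jax._src.lib):

    def check_jaxlib_version(jaxlib_version: str, minimum: str) -> None:
      # Compare dotted versions numerically so that "0.4.10" > "0.4.7".
      as_tuple = lambda v: tuple(int(p) for p in v.split(".") if p.isdigit())
      if as_tuple(jaxlib_version) < as_tuple(minimum):
        raise RuntimeError(
            f"jaxlib is version {jaxlib_version}, but this version of jax "
            f"requires jaxlib >= {minimum}.")

    check_jaxlib_version("0.4.6", "0.4.7")  # raises RuntimeError after this commit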
@@ -64,7 +64,6 @@ from jax._src.lax import lax as lax_internal
 from jax._src.lib import jax_jit
 from jax._src.lib import xla_client as xc
 from jax._src.lib import pmap_lib
-from jax._src.lib import xla_extension_version
 from jax._src.sharding_impls import PmapSharding
 from jax._src.traceback_util import api_boundary
 from jax._src.tree_util import broadcast_prefix, generate_key_paths
@@ -2605,8 +2604,6 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): #
       raise ValueError("the shards passed to device_put_sharded must have "
                        f"consistent shape and dtype, but got {a1} and {a2}.")
     stacked_aval = avals[0].update(shape=(len(devices),) + avals[0].shape)
-    if xla_extension_version < 139:
-      xs = [xla.canonicalize_dtype(arg) for arg in xs]
     sharding_spec = pxla._create_pmap_sharding_spec(stacked_aval)
     return pxla.batched_device_put(
         stacked_aval, PmapSharding(np.array(devices), sharding_spec),
@@ -2656,8 +2653,6 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811
     assert (isinstance(aval, ShapedArray) and
             len(xla.aval_to_xla_shapes(aval)) == 1)
     sharding_spec = pxla._create_pmap_sharding_spec(aval)
-    if xla_extension_version < 139:
-      x = xla.canonicalize_dtype(x)
     buf = jax.device_put(x, devices[0])
     return pxla.batched_device_put(
         aval, PmapSharding(np.array(devices), sharding_spec),
@@ -32,7 +32,6 @@ from jax._src import profiler
 from jax._src.config import config
 from jax._src.util import use_cpp_class, use_cpp_method
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src import api
 from jax._src.typing import ArrayLike
 from jax.interpreters import mlir
@@ -341,14 +340,7 @@ class ArrayImpl(basearray.Array):
         return _single_device_array_from_buf(arr, committed=False)
       return lax_numpy._rewriting_take(self, idx)
     else:
-      if xla_extension_version >= 144:
-        return lax_numpy._rewriting_take(self, idx)
-      else:
-        if (dispatch.is_single_device_sharding(self.sharding) or
-            self.is_fully_replicated or _is_reduced_on_dim(idx)):
-          return lax_numpy._rewriting_take(self, idx)
-        else:
-          return api.device_put(self._value[idx])
+      return lax_numpy._rewriting_take(self, idx)

   def __iter__(self):
     if self.ndim == 0:
@@ -404,7 +396,7 @@ class ArrayImpl(basearray.Array):
                  'named_shape': self.aval.named_shape}
     return (_reconstruct_array, (fun, args, arr_state, aval_state))

-  @use_cpp_method(xla_extension_version >= 138)
+  @use_cpp_method()
   def unsafe_buffer_pointer(self):
     if len(self._arrays) != 1:
       raise ValueError("unsafe_buffer_pointer() is supported only for unsharded"
@@ -412,14 +404,14 @@ class ArrayImpl(basearray.Array):
     return self._arrays[0].unsafe_buffer_pointer()

   @property
-  @use_cpp_method(xla_extension_version >= 138)
+  @use_cpp_method()
   def __cuda_array_interface__(self):
     if len(self._arrays) != 1:
       raise ValueError("__cuda_array_interface__() is supported only for "
                        "unsharded arrays.")
     return self._arrays[0].__cuda_array_interface__  # pytype: disable=attribute-error  # bind-properties

-  @use_cpp_method(xla_extension_version >= 138)
+  @use_cpp_method()
   def on_device_size_in_bytes(self):
     """Returns the total global on-device size of the array in bytes."""
     arr = self._arrays[0]
@@ -495,7 +487,7 @@ class ArrayImpl(basearray.Array):
       out.append(Shard(global_d, self.sharding, self.shape, array))
     return out

-  @use_cpp_method(xla_extension_version >= 138)
+  @use_cpp_method()
   def delete(self):
     if self._arrays is None:
       return
@@ -524,11 +516,11 @@ class ArrayImpl(basearray.Array):
       db.block_until_ready()
     return self

-  @use_cpp_method(xla_extension_version >= 138)
+  @use_cpp_method()
   def _single_device_array_to_np_array(self):
     return np.asarray(self._arrays[0])

-  @use_cpp_method(xla_extension_version >= 138)
+  @use_cpp_method()
   def _copy_single_device_array_to_host_async(self):
     self._arrays[0].copy_to_host_async()
@@ -541,10 +533,7 @@ class ArrayImpl(basearray.Array):
       return
     copy_plan = _create_copy_plan(self._arrays, self.sharding, self.shape)
     for _, arr in copy_plan:
-      if xla_extension_version >= 140:
-        arr._copy_single_device_array_to_host_async()
-      else:
-        arr.copy_to_host_async()
+      arr._copy_single_device_array_to_host_async()

   @property
   @functools.partial(profiler.annotate_function, name="np.asarray(jax.Array)")
@@ -568,17 +557,11 @@ class ArrayImpl(basearray.Array):

     copy_plan = _create_copy_plan(self._arrays, self.sharding, self.shape)
     for _, arr in copy_plan:
-      if xla_extension_version >= 140:
-        arr._copy_single_device_array_to_host_async()
-      else:
-        arr.copy_to_host_async()
+      arr._copy_single_device_array_to_host_async()

     npy_value = np.empty(self.shape, self.dtype)
     for ind, arr in copy_plan:
-      if xla_extension_version >= 140:
-        npy_value[ind] = arr._single_device_array_to_np_array()
-      else:
-        npy_value[ind] = np.asarray(arr)
+      npy_value[ind] = arr._single_device_array_to_np_array()
     self._npy_value = npy_value  # type: ignore
     self._npy_value.flags.writeable = False
     # https://docs.python.org/3/library/typing.html#typing.cast
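The decorator edits above all make the same change. use_cpp_method(is_enabled) installs the C++ implementation of a method only when the flag is true, keeping the Python body as a fallback; conceptually it works like this hypothetical simplification (not the actual jax._src.util code):

    def use_cpp_method(is_enabled=True):
      # Tag the method; use_cpp_class swaps in the C++ version when tagged.
      def decorator(fn):
        fn._use_cpp = is_enabled
        return fn
      return decorator

Before this commit the flag was computed as xla_extension_version >= 138; with the minimum jaxlib now guaranteeing version 144, the flag is always true and the argument is dropped.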
@@ -25,7 +25,6 @@ from jax._src import lib
 from jax._src.lib import jax_jit
 from jax._src.lib import transfer_guard_lib
 from jax._src.lib import xla_client
-from jax._src.lib import xla_extension_version

 logger = logging.getLogger(__name__)

@@ -758,8 +757,6 @@ def _update_jax_array_global(val):
         'jax.config.jax_array cannot be disabled after jax 0.4.7 release.'
         ' Please downgrade to jax and jaxlib 0.4.6 if you want to disable'
         ' jax.config.jax_array.')
-  if xla_extension_version < 141:
-    lib.jax_jit.global_state().jax_array = val

 def _update_jax_array_thread_local(val):
   if val is not None and not val:
@@ -767,8 +764,6 @@ def _update_jax_array_thread_local(val):
         'jax.config.jax_array cannot be disabled after jax 0.4.7 release.'
         ' Please downgrade to jax and jaxlib 0.4.6 if you want to disable'
         ' jax.config.jax_array.')
-  if xla_extension_version < 141:
-    lib.jax_jit.thread_local_state().jax_array = val

 jax_array = config.define_bool_state(
     name='jax_array',
@@ -17,7 +17,6 @@ from jax._src import device_array
 from jax._src import array
 from jax._src import xla_bridge
 from jax._src.lib import xla_client
-from jax._src.lib import xla_extension_version


 SUPPORTED_DTYPES = frozenset({
@@ -40,24 +39,12 @@ def to_dlpack(x: device_array.DeviceArrayProtocol, take_ownership: bool = False)
   undefined behavior if the DLPack consumer writes to a buffer that JAX
   owns.
   """
-  if xla_extension_version >= 140:
-    if not isinstance(x, array.ArrayImpl):
-      raise TypeError("Argument to to_dlpack must be a jax.Array, "
-                      f"got {type(x)}")
-    assert len(x.devices()) == 1
-    return xla_client._xla.buffer_to_dlpack_managed_tensor(
-        x.addressable_data(0), take_ownership=take_ownership)  # type: ignore
-  else:
-    if not isinstance(x, (device_array.DeviceArray, array.ArrayImpl)):
-      raise TypeError("Argument to to_dlpack must be a DeviceArray or Array, "
-                      f"got {type(x)}")
-    if isinstance(x, array.ArrayImpl):
-      assert len(x._arrays) == 1
-      buf = x._arrays[0]
-    else:
-      buf = x.device_buffer
-    return xla_client._xla.buffer_to_dlpack_managed_tensor(
-        buf, take_ownership=take_ownership)
+  if not isinstance(x, array.ArrayImpl):
+    raise TypeError("Argument to to_dlpack must be a jax.Array, "
+                    f"got {type(x)}")
+  assert len(x.devices()) == 1
+  return xla_client._xla.buffer_to_dlpack_managed_tensor(
+      x.addressable_data(0), take_ownership=take_ownership)  # type: ignore

 def from_dlpack(dlpack):
   """Returns a ``DeviceArray`` representation of a DLPack tensor.
@@ -80,14 +67,5 @@ def from_dlpack(dlpack):
   except RuntimeError:
     gpu_backend = None

-  if xla_extension_version >= 140:
-    return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
-        dlpack, cpu_backend, gpu_backend))
-  else:
-    buf = xla_client._xla.dlpack_managed_tensor_to_buffer(
-        dlpack, cpu_backend, gpu_backend)
-    if isinstance(buf, array.ArrayImpl):
-      return jnp.asarray(buf)  # asarray ensures dtype canonicalization
-    else:
-      return jnp.asarray(array._single_device_array_from_buf(
-          buf, committed=buf.device() is not None))
+  return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
+      dlpack, cpu_backend, gpu_backend))
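With the fallback branches gone, to_dlpack and from_dlpack form a straight export/import pair on jax.Array. A usage sketch of the round trip through the public jax.dlpack API:

    import jax.numpy as jnp
    from jax.dlpack import to_dlpack, from_dlpack

    x = jnp.arange(4.0)
    capsule = to_dlpack(x)    # export: jax.Array -> DLPack capsule
    y = from_dlpack(capsule)  # import: DLPack capsule -> jax.Array
    assert bool((x == y).all())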
@@ -63,7 +63,7 @@ from jax._src.lax.utils import (
 )
 from jax._src.lib import pytree
 from jax._src import xla_bridge
-from jax._src.lib import xla_client, xla_extension_version
+from jax._src.lib import xla_client
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import chlo
 from jax._src.lib.mlir.dialects import hlo
@@ -4209,14 +4209,8 @@ def _rng_bit_generator_lowering(
       (key_shape == [2] and key_etype == u64_type)), (key_shape, key_etype)
   dtype = np.dtype(dtype)
   etype = mlir.dtype_to_ir_type(dtype)
-  if (
-      dtype == np.dtype('uint32')
-      or dtype == np.dtype('uint64')
-      or (
-          xla_extension_version >= 140
-          and (dtype == np.dtype('uint16') or dtype == np.dtype('uint8'))
-      )
-  ):
+  if dtype in (np.dtype('uint8'), np.dtype('uint16'), np.dtype('uint32'),
+               np.dtype('uint64')):
     rbg_etype = etype
   else:
     rbg_etype = u32_type
@@ -32,7 +32,6 @@ from jax._src import core
 from jax._src import dtypes
 from jax._src import prng
 from jax._src import xla_bridge
-from jax._src.lib import xla_extension_version
 from jax._src.api import jit, vmap
 from jax._src.core import NamedShape
 from jax._src.interpreters import ad
@@ -286,7 +285,7 @@ def _uniform(key, shape, dtype, minval, maxval) -> Array:
     raise TypeError("uniform only accepts 32- or 64-bit dtypes.")

   rng_bits = nbits
-  if xla_extension_version >= 140 and nmant < 8:
+  if nmant < 8:
     rng_bits = 8
   bits = _random_bits(key, rng_bits, shape)
   uint_dtype = UINT_DTYPES[nbits]
@@ -319,15 +319,7 @@ def register_pjrt_plugin_factories(plugins_from_env: str) -> None:
       options = None

     xla_client.load_pjrt_plugin_dynamically(name, library_path)
-    if lib.xla_extension_version >= 134:
-      return xla_client.make_c_api_client(name, options)
-    else:
-      if options:
-        raise ValueError(
-            'Setting PJRT plugin options through json file requires'
-            ' jaxlib.xla_extension_version >= 134.'
-        )
-      return xla_client.make_c_api_client(name)
+    return xla_client.make_c_api_client(name, options)

   return factory

@@ -1964,16 +1964,10 @@ def _initialize_outfeed_receiver(
   device_repr = ", ".join([str(d) for d in devices_with_outfeed])
   logger.debug("Starting outfeed_receiver for %s. max_callback_queue_size_bytes=%s",
                device_repr, max_callback_queue_size_bytes)
-  if jaxlib.xla_extension_version > 143:
-    # TODO(phawkins): remove type:ignore after minimum jaxlib version bump
-    _callback_handler_data.receiver = outfeed_receiver_module.start(
-        _callback_input_received, tuple(clients_with_outfeed),
-        max_callback_queue_size_bytes,
-        xb.get_compile_options(1, 1).executable_build_options)  # type:ignore
-  else:
-    _callback_handler_data.receiver = outfeed_receiver_module.start(
-        _callback_input_received, tuple(clients_with_outfeed),
-        max_callback_queue_size_bytes)
+  _callback_handler_data.receiver = outfeed_receiver_module.start(
+      _callback_input_received, tuple(clients_with_outfeed),
+      max_callback_queue_size_bytes,
+      xb.get_compile_options(1, 1).executable_build_options)

   def exit_handler():
     # Prevent logging usage during compilation, gives errors under pytest
@@ -129,15 +129,11 @@ def serialize_native(fun_jax: Callable,
   assert len(module_kept_var_idx) == len(args_avals)
   mlir_module = compute_dim_vars(mlir_module, args_avals)

-  if xla_client.mlir_api_version >= 46:
-    xla_call_module_version = 4
-    mlir_str = mlir.module_to_bytecode(mlir_module)
-    target_version = stablehlo.get_earliest_forward_compatible_version()
-    mlir_module_serialized = xla_client._xla.mlir.serialize_portable_artifact(
-        mlir_str, target_version)
-  else:
-    xla_call_module_version = 3
-    mlir_module_serialized = mlir.module_to_bytecode(mlir_module)
+  xla_call_module_version = 4
+  mlir_str = mlir.module_to_bytecode(mlir_module)
+  target_version = stablehlo.get_earliest_forward_compatible_version()
+  mlir_module_serialized = xla_client._xla.mlir.serialize_portable_artifact(
+      mlir_str, target_version)

   # Figure out the result types and shapes
   if "global_out_avals" in lowered.compile_args:
@@ -2631,17 +2631,6 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
     if harness.group_name in require_stablehlo_feature_support:
       raise unittest.SkipTest(
          "native lowering with shape polymorphism requires additional StableHLO feature support")
-    # API version 47 supports CHLO ops that decompose into shape dialect ops
-    if xla_client.mlir_api_version < 47:
-      require_stablehlo_feature_support_shape_dialect = {
-          "vmap_acosh", "vmap_asin", "vmap_asinh", "vmap_atan", "vmap_atanh",
-          "vmap_bessel_i1e", "vmap_cosh", "vmap_digamma", "vmap_erf",
-          "vmap_erfc", "vmap_lgamma", "vmap_nextafter",
-          "vmap_nextafter_broadcasting", "vmap_sinh"
-      }
-      if harness.group_name in require_stablehlo_feature_support_shape_dialect:
-        raise unittest.SkipTest(
-            "native lowering with shape polymorphism requires additional StableHLO feature support")
     if (jtu.device_under_test() == "tpu" and
         harness.fullname in [
             "jnp.cumsum_reduce_axis=poly",
@@ -16,7 +16,7 @@
 # eval()-ed by setup.py, so it should not have any dependencies.

 __version__ = "0.4.8"
-_minimum_jaxlib_version = "0.4.6"
+_minimum_jaxlib_version = "0.4.7"

 def _version_as_tuple(version_str):
   return tuple(int(i) for i in version_str.split(".") if i.isdigit())
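_version_as_tuple makes the minimum-version comparison numeric rather than lexicographic, which matters once a version component reaches two digits:

    _version_as_tuple("0.4.7")                                 # (0, 4, 7)
    _version_as_tuple("0.4.10") > _version_as_tuple("0.4.7")   # True
    "0.4.10" > "0.4.7"                                         # False: string compare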
@@ -19,9 +19,7 @@ from absl.testing import parameterized
 import jax
 from jax._src import core
 from jax._src import api_util
-from jax._src.interpreters import xla
 from jax._src.interpreters import pxla
-from jax._src.lib import xla_extension_version
 from jax import dtypes
 from jax._src import lib as jaxlib
 from jax import numpy as jnp
@@ -33,8 +31,6 @@ config.parse_flags_with_absl()

 def _cpp_device_put(value, device):
   aval = api_util.shaped_abstractify(value)
-  if xla_extension_version < 139:
-    value = xla.canonicalize_dtype(value)
   return pxla.batched_device_put(
       aval, jax.sharding.SingleDeviceSharding(device), [value], [device])

@@ -39,7 +39,6 @@ from jax.interpreters import xla
 from jax._src.interpreters import mlir
 from jax.interpreters import batching
 from jax._src import array
-from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir.dialects import hlo
 from jax._src import dtypes
 from jax._src.interpreters import pxla
@@ -2855,8 +2854,6 @@ class FooTyRules:
       buf, = arr._arrays
     else:
       buf, = arr
-    if xla_extension_version < 140:
-      buf.aval = core.ShapedArray(buf.shape, buf.dtype)
     return FooArray(aval.shape, buf)
   return handler
