mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Bump minimum jaxlib version from 0.4.6 to 0.4.7.
Also removes a bunch of dead version guards (0.4.7 has xla_extension_version 144 and mlir_api_version 47)
This commit is contained in:
parent
2f105bde2d
commit
00acf459c6
@ -8,6 +8,9 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
|
|
||||||
## jax 0.4.8
|
## jax 0.4.8
|
||||||
|
|
||||||
|
* Changes
|
||||||
|
* The minimum jaxlib version has been bumped from 0.4.6 to 0.4.7.
|
||||||
|
|
||||||
* Deprecations
|
* Deprecations
|
||||||
* CUDA 11.4 support has been dropped. JAX GPU wheels only support
|
* CUDA 11.4 support has been dropped. JAX GPU wheels only support
|
||||||
CUDA 11.8 and CUDA 12. Older CUDA versions may work if jaxlib is built
|
CUDA 11.8 and CUDA 12. Older CUDA versions may work if jaxlib is built
|
||||||
|
@ -64,7 +64,6 @@ from jax._src.lax import lax as lax_internal
|
|||||||
from jax._src.lib import jax_jit
|
from jax._src.lib import jax_jit
|
||||||
from jax._src.lib import xla_client as xc
|
from jax._src.lib import xla_client as xc
|
||||||
from jax._src.lib import pmap_lib
|
from jax._src.lib import pmap_lib
|
||||||
from jax._src.lib import xla_extension_version
|
|
||||||
from jax._src.sharding_impls import PmapSharding
|
from jax._src.sharding_impls import PmapSharding
|
||||||
from jax._src.traceback_util import api_boundary
|
from jax._src.traceback_util import api_boundary
|
||||||
from jax._src.tree_util import broadcast_prefix, generate_key_paths
|
from jax._src.tree_util import broadcast_prefix, generate_key_paths
|
||||||
@ -2605,8 +2604,6 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): #
|
|||||||
raise ValueError("the shards passed to device_put_sharded must have "
|
raise ValueError("the shards passed to device_put_sharded must have "
|
||||||
f"consistent shape and dtype, but got {a1} and {a2}.")
|
f"consistent shape and dtype, but got {a1} and {a2}.")
|
||||||
stacked_aval = avals[0].update(shape=(len(devices),) + avals[0].shape)
|
stacked_aval = avals[0].update(shape=(len(devices),) + avals[0].shape)
|
||||||
if xla_extension_version < 139:
|
|
||||||
xs = [xla.canonicalize_dtype(arg) for arg in xs]
|
|
||||||
sharding_spec = pxla._create_pmap_sharding_spec(stacked_aval)
|
sharding_spec = pxla._create_pmap_sharding_spec(stacked_aval)
|
||||||
return pxla.batched_device_put(
|
return pxla.batched_device_put(
|
||||||
stacked_aval, PmapSharding(np.array(devices), sharding_spec),
|
stacked_aval, PmapSharding(np.array(devices), sharding_spec),
|
||||||
@ -2656,8 +2653,6 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811
|
|||||||
assert (isinstance(aval, ShapedArray) and
|
assert (isinstance(aval, ShapedArray) and
|
||||||
len(xla.aval_to_xla_shapes(aval)) == 1)
|
len(xla.aval_to_xla_shapes(aval)) == 1)
|
||||||
sharding_spec = pxla._create_pmap_sharding_spec(aval)
|
sharding_spec = pxla._create_pmap_sharding_spec(aval)
|
||||||
if xla_extension_version < 139:
|
|
||||||
x = xla.canonicalize_dtype(x)
|
|
||||||
buf = jax.device_put(x, devices[0])
|
buf = jax.device_put(x, devices[0])
|
||||||
return pxla.batched_device_put(
|
return pxla.batched_device_put(
|
||||||
aval, PmapSharding(np.array(devices), sharding_spec),
|
aval, PmapSharding(np.array(devices), sharding_spec),
|
||||||
|
@ -32,7 +32,6 @@ from jax._src import profiler
|
|||||||
from jax._src.config import config
|
from jax._src.config import config
|
||||||
from jax._src.util import use_cpp_class, use_cpp_method
|
from jax._src.util import use_cpp_class, use_cpp_method
|
||||||
from jax._src.lib import xla_client as xc
|
from jax._src.lib import xla_client as xc
|
||||||
from jax._src.lib import xla_extension_version
|
|
||||||
from jax._src import api
|
from jax._src import api
|
||||||
from jax._src.typing import ArrayLike
|
from jax._src.typing import ArrayLike
|
||||||
from jax.interpreters import mlir
|
from jax.interpreters import mlir
|
||||||
@ -341,14 +340,7 @@ class ArrayImpl(basearray.Array):
|
|||||||
return _single_device_array_from_buf(arr, committed=False)
|
return _single_device_array_from_buf(arr, committed=False)
|
||||||
return lax_numpy._rewriting_take(self, idx)
|
return lax_numpy._rewriting_take(self, idx)
|
||||||
else:
|
else:
|
||||||
if xla_extension_version >= 144:
|
return lax_numpy._rewriting_take(self, idx)
|
||||||
return lax_numpy._rewriting_take(self, idx)
|
|
||||||
else:
|
|
||||||
if (dispatch.is_single_device_sharding(self.sharding) or
|
|
||||||
self.is_fully_replicated or _is_reduced_on_dim(idx)):
|
|
||||||
return lax_numpy._rewriting_take(self, idx)
|
|
||||||
else:
|
|
||||||
return api.device_put(self._value[idx])
|
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
if self.ndim == 0:
|
if self.ndim == 0:
|
||||||
@ -404,7 +396,7 @@ class ArrayImpl(basearray.Array):
|
|||||||
'named_shape': self.aval.named_shape}
|
'named_shape': self.aval.named_shape}
|
||||||
return (_reconstruct_array, (fun, args, arr_state, aval_state))
|
return (_reconstruct_array, (fun, args, arr_state, aval_state))
|
||||||
|
|
||||||
@use_cpp_method(xla_extension_version >= 138)
|
@use_cpp_method()
|
||||||
def unsafe_buffer_pointer(self):
|
def unsafe_buffer_pointer(self):
|
||||||
if len(self._arrays) != 1:
|
if len(self._arrays) != 1:
|
||||||
raise ValueError("unsafe_buffer_pointer() is supported only for unsharded"
|
raise ValueError("unsafe_buffer_pointer() is supported only for unsharded"
|
||||||
@ -412,14 +404,14 @@ class ArrayImpl(basearray.Array):
|
|||||||
return self._arrays[0].unsafe_buffer_pointer()
|
return self._arrays[0].unsafe_buffer_pointer()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@use_cpp_method(xla_extension_version >= 138)
|
@use_cpp_method()
|
||||||
def __cuda_array_interface__(self):
|
def __cuda_array_interface__(self):
|
||||||
if len(self._arrays) != 1:
|
if len(self._arrays) != 1:
|
||||||
raise ValueError("__cuda_array_interface__() is supported only for "
|
raise ValueError("__cuda_array_interface__() is supported only for "
|
||||||
"unsharded arrays.")
|
"unsharded arrays.")
|
||||||
return self._arrays[0].__cuda_array_interface__ # pytype: disable=attribute-error # bind-properties
|
return self._arrays[0].__cuda_array_interface__ # pytype: disable=attribute-error # bind-properties
|
||||||
|
|
||||||
@use_cpp_method(xla_extension_version >= 138)
|
@use_cpp_method()
|
||||||
def on_device_size_in_bytes(self):
|
def on_device_size_in_bytes(self):
|
||||||
"""Returns the total global on-device size of the array in bytes."""
|
"""Returns the total global on-device size of the array in bytes."""
|
||||||
arr = self._arrays[0]
|
arr = self._arrays[0]
|
||||||
@ -495,7 +487,7 @@ class ArrayImpl(basearray.Array):
|
|||||||
out.append(Shard(global_d, self.sharding, self.shape, array))
|
out.append(Shard(global_d, self.sharding, self.shape, array))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@use_cpp_method(xla_extension_version >= 138)
|
@use_cpp_method()
|
||||||
def delete(self):
|
def delete(self):
|
||||||
if self._arrays is None:
|
if self._arrays is None:
|
||||||
return
|
return
|
||||||
@ -524,11 +516,11 @@ class ArrayImpl(basearray.Array):
|
|||||||
db.block_until_ready()
|
db.block_until_ready()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@use_cpp_method(xla_extension_version >= 138)
|
@use_cpp_method()
|
||||||
def _single_device_array_to_np_array(self):
|
def _single_device_array_to_np_array(self):
|
||||||
return np.asarray(self._arrays[0])
|
return np.asarray(self._arrays[0])
|
||||||
|
|
||||||
@use_cpp_method(xla_extension_version >= 138)
|
@use_cpp_method()
|
||||||
def _copy_single_device_array_to_host_async(self):
|
def _copy_single_device_array_to_host_async(self):
|
||||||
self._arrays[0].copy_to_host_async()
|
self._arrays[0].copy_to_host_async()
|
||||||
|
|
||||||
@ -541,10 +533,7 @@ class ArrayImpl(basearray.Array):
|
|||||||
return
|
return
|
||||||
copy_plan = _create_copy_plan(self._arrays, self.sharding, self.shape)
|
copy_plan = _create_copy_plan(self._arrays, self.sharding, self.shape)
|
||||||
for _, arr in copy_plan:
|
for _, arr in copy_plan:
|
||||||
if xla_extension_version >= 140:
|
arr._copy_single_device_array_to_host_async()
|
||||||
arr._copy_single_device_array_to_host_async()
|
|
||||||
else:
|
|
||||||
arr.copy_to_host_async()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@functools.partial(profiler.annotate_function, name="np.asarray(jax.Array)")
|
@functools.partial(profiler.annotate_function, name="np.asarray(jax.Array)")
|
||||||
@ -568,17 +557,11 @@ class ArrayImpl(basearray.Array):
|
|||||||
|
|
||||||
copy_plan = _create_copy_plan(self._arrays, self.sharding, self.shape)
|
copy_plan = _create_copy_plan(self._arrays, self.sharding, self.shape)
|
||||||
for _, arr in copy_plan:
|
for _, arr in copy_plan:
|
||||||
if xla_extension_version >= 140:
|
arr._copy_single_device_array_to_host_async()
|
||||||
arr._copy_single_device_array_to_host_async()
|
|
||||||
else:
|
|
||||||
arr.copy_to_host_async()
|
|
||||||
|
|
||||||
npy_value = np.empty(self.shape, self.dtype)
|
npy_value = np.empty(self.shape, self.dtype)
|
||||||
for ind, arr in copy_plan:
|
for ind, arr in copy_plan:
|
||||||
if xla_extension_version >= 140:
|
npy_value[ind] = arr._single_device_array_to_np_array()
|
||||||
npy_value[ind] = arr._single_device_array_to_np_array()
|
|
||||||
else:
|
|
||||||
npy_value[ind] = np.asarray(arr)
|
|
||||||
self._npy_value = npy_value # type: ignore
|
self._npy_value = npy_value # type: ignore
|
||||||
self._npy_value.flags.writeable = False
|
self._npy_value.flags.writeable = False
|
||||||
# https://docs.python.org/3/library/typing.html#typing.cast
|
# https://docs.python.org/3/library/typing.html#typing.cast
|
||||||
|
@ -25,7 +25,6 @@ from jax._src import lib
|
|||||||
from jax._src.lib import jax_jit
|
from jax._src.lib import jax_jit
|
||||||
from jax._src.lib import transfer_guard_lib
|
from jax._src.lib import transfer_guard_lib
|
||||||
from jax._src.lib import xla_client
|
from jax._src.lib import xla_client
|
||||||
from jax._src.lib import xla_extension_version
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -758,8 +757,6 @@ def _update_jax_array_global(val):
|
|||||||
'jax.config.jax_array cannot be disabled after jax 0.4.7 release.'
|
'jax.config.jax_array cannot be disabled after jax 0.4.7 release.'
|
||||||
' Please downgrade to jax and jaxlib 0.4.6 if you want to disable'
|
' Please downgrade to jax and jaxlib 0.4.6 if you want to disable'
|
||||||
' jax.config.jax_array.')
|
' jax.config.jax_array.')
|
||||||
if xla_extension_version < 141:
|
|
||||||
lib.jax_jit.global_state().jax_array = val
|
|
||||||
|
|
||||||
def _update_jax_array_thread_local(val):
|
def _update_jax_array_thread_local(val):
|
||||||
if val is not None and not val:
|
if val is not None and not val:
|
||||||
@ -767,8 +764,6 @@ def _update_jax_array_thread_local(val):
|
|||||||
'jax.config.jax_array cannot be disabled after jax 0.4.7 release.'
|
'jax.config.jax_array cannot be disabled after jax 0.4.7 release.'
|
||||||
' Please downgrade to jax and jaxlib 0.4.6 if you want to disable'
|
' Please downgrade to jax and jaxlib 0.4.6 if you want to disable'
|
||||||
' jax.config.jax_array.')
|
' jax.config.jax_array.')
|
||||||
if xla_extension_version < 141:
|
|
||||||
lib.jax_jit.thread_local_state().jax_array = val
|
|
||||||
|
|
||||||
jax_array = config.define_bool_state(
|
jax_array = config.define_bool_state(
|
||||||
name='jax_array',
|
name='jax_array',
|
||||||
|
@ -17,7 +17,6 @@ from jax._src import device_array
|
|||||||
from jax._src import array
|
from jax._src import array
|
||||||
from jax._src import xla_bridge
|
from jax._src import xla_bridge
|
||||||
from jax._src.lib import xla_client
|
from jax._src.lib import xla_client
|
||||||
from jax._src.lib import xla_extension_version
|
|
||||||
|
|
||||||
|
|
||||||
SUPPORTED_DTYPES = frozenset({
|
SUPPORTED_DTYPES = frozenset({
|
||||||
@ -40,24 +39,12 @@ def to_dlpack(x: device_array.DeviceArrayProtocol, take_ownership: bool = False)
|
|||||||
undefined behavior if the DLPack consumer writes to a buffer that JAX
|
undefined behavior if the DLPack consumer writes to a buffer that JAX
|
||||||
owns.
|
owns.
|
||||||
"""
|
"""
|
||||||
if xla_extension_version >= 140:
|
if not isinstance(x, array.ArrayImpl):
|
||||||
if not isinstance(x, array.ArrayImpl):
|
raise TypeError("Argument to to_dlpack must be a jax.Array, "
|
||||||
raise TypeError("Argument to to_dlpack must be a jax.Array, "
|
f"got {type(x)}")
|
||||||
f"got {type(x)}")
|
assert len(x.devices()) == 1
|
||||||
assert len(x.devices()) == 1
|
return xla_client._xla.buffer_to_dlpack_managed_tensor(
|
||||||
return xla_client._xla.buffer_to_dlpack_managed_tensor(
|
x.addressable_data(0), take_ownership=take_ownership) # type: ignore
|
||||||
x.addressable_data(0), take_ownership=take_ownership) # type: ignore
|
|
||||||
else:
|
|
||||||
if not isinstance(x, (device_array.DeviceArray, array.ArrayImpl)):
|
|
||||||
raise TypeError("Argument to to_dlpack must be a DeviceArray or Array, "
|
|
||||||
f"got {type(x)}")
|
|
||||||
if isinstance(x, array.ArrayImpl):
|
|
||||||
assert len(x._arrays) == 1
|
|
||||||
buf = x._arrays[0]
|
|
||||||
else:
|
|
||||||
buf = x.device_buffer
|
|
||||||
return xla_client._xla.buffer_to_dlpack_managed_tensor(
|
|
||||||
buf, take_ownership=take_ownership)
|
|
||||||
|
|
||||||
def from_dlpack(dlpack):
|
def from_dlpack(dlpack):
|
||||||
"""Returns a ``DeviceArray`` representation of a DLPack tensor.
|
"""Returns a ``DeviceArray`` representation of a DLPack tensor.
|
||||||
@ -80,14 +67,5 @@ def from_dlpack(dlpack):
|
|||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
gpu_backend = None
|
gpu_backend = None
|
||||||
|
|
||||||
if xla_extension_version >= 140:
|
return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
|
||||||
return jnp.asarray(xla_client._xla.dlpack_managed_tensor_to_buffer(
|
dlpack, cpu_backend, gpu_backend))
|
||||||
dlpack, cpu_backend, gpu_backend))
|
|
||||||
else:
|
|
||||||
buf = xla_client._xla.dlpack_managed_tensor_to_buffer(
|
|
||||||
dlpack, cpu_backend, gpu_backend)
|
|
||||||
if isinstance(buf, array.ArrayImpl):
|
|
||||||
return jnp.asarray(buf) # asarray ensures dtype canonicalization
|
|
||||||
else:
|
|
||||||
return jnp.asarray(array._single_device_array_from_buf(
|
|
||||||
buf, committed=buf.device() is not None))
|
|
||||||
|
@ -63,7 +63,7 @@ from jax._src.lax.utils import (
|
|||||||
)
|
)
|
||||||
from jax._src.lib import pytree
|
from jax._src.lib import pytree
|
||||||
from jax._src import xla_bridge
|
from jax._src import xla_bridge
|
||||||
from jax._src.lib import xla_client, xla_extension_version
|
from jax._src.lib import xla_client
|
||||||
from jax._src.lib.mlir import ir
|
from jax._src.lib.mlir import ir
|
||||||
from jax._src.lib.mlir.dialects import chlo
|
from jax._src.lib.mlir.dialects import chlo
|
||||||
from jax._src.lib.mlir.dialects import hlo
|
from jax._src.lib.mlir.dialects import hlo
|
||||||
@ -4209,14 +4209,8 @@ def _rng_bit_generator_lowering(
|
|||||||
(key_shape == [2] and key_etype == u64_type)), (key_shape, key_etype)
|
(key_shape == [2] and key_etype == u64_type)), (key_shape, key_etype)
|
||||||
dtype = np.dtype(dtype)
|
dtype = np.dtype(dtype)
|
||||||
etype = mlir.dtype_to_ir_type(dtype)
|
etype = mlir.dtype_to_ir_type(dtype)
|
||||||
if (
|
if dtype in (np.dtype('uint8'), np.dtype('uint16'), np.dtype('uint32'),
|
||||||
dtype == np.dtype('uint32')
|
np.dtype('uint64')):
|
||||||
or dtype == np.dtype('uint64')
|
|
||||||
or (
|
|
||||||
xla_extension_version >= 140
|
|
||||||
and (dtype == np.dtype('uint16') or dtype == np.dtype('uint8'))
|
|
||||||
)
|
|
||||||
):
|
|
||||||
rbg_etype = etype
|
rbg_etype = etype
|
||||||
else:
|
else:
|
||||||
rbg_etype = u32_type
|
rbg_etype = u32_type
|
||||||
|
@ -32,7 +32,6 @@ from jax._src import core
|
|||||||
from jax._src import dtypes
|
from jax._src import dtypes
|
||||||
from jax._src import prng
|
from jax._src import prng
|
||||||
from jax._src import xla_bridge
|
from jax._src import xla_bridge
|
||||||
from jax._src.lib import xla_extension_version
|
|
||||||
from jax._src.api import jit, vmap
|
from jax._src.api import jit, vmap
|
||||||
from jax._src.core import NamedShape
|
from jax._src.core import NamedShape
|
||||||
from jax._src.interpreters import ad
|
from jax._src.interpreters import ad
|
||||||
@ -286,7 +285,7 @@ def _uniform(key, shape, dtype, minval, maxval) -> Array:
|
|||||||
raise TypeError("uniform only accepts 32- or 64-bit dtypes.")
|
raise TypeError("uniform only accepts 32- or 64-bit dtypes.")
|
||||||
|
|
||||||
rng_bits = nbits
|
rng_bits = nbits
|
||||||
if xla_extension_version >= 140 and nmant < 8:
|
if nmant < 8:
|
||||||
rng_bits = 8
|
rng_bits = 8
|
||||||
bits = _random_bits(key, rng_bits, shape)
|
bits = _random_bits(key, rng_bits, shape)
|
||||||
uint_dtype = UINT_DTYPES[nbits]
|
uint_dtype = UINT_DTYPES[nbits]
|
||||||
|
@ -319,15 +319,7 @@ def register_pjrt_plugin_factories(plugins_from_env: str) -> None:
|
|||||||
options = None
|
options = None
|
||||||
|
|
||||||
xla_client.load_pjrt_plugin_dynamically(name, library_path)
|
xla_client.load_pjrt_plugin_dynamically(name, library_path)
|
||||||
if lib.xla_extension_version >= 134:
|
return xla_client.make_c_api_client(name, options)
|
||||||
return xla_client.make_c_api_client(name, options)
|
|
||||||
else:
|
|
||||||
if options:
|
|
||||||
raise ValueError(
|
|
||||||
'Setting PJRT plugin options through json file requires'
|
|
||||||
' jaxlib.xla_extension_version >= 134.'
|
|
||||||
)
|
|
||||||
return xla_client.make_c_api_client(name)
|
|
||||||
|
|
||||||
return factory
|
return factory
|
||||||
|
|
||||||
|
@ -1964,16 +1964,10 @@ def _initialize_outfeed_receiver(
|
|||||||
device_repr = ", ".join([str(d) for d in devices_with_outfeed])
|
device_repr = ", ".join([str(d) for d in devices_with_outfeed])
|
||||||
logger.debug("Starting outfeed_receiver for %s. max_callback_queue_size_bytes=%s",
|
logger.debug("Starting outfeed_receiver for %s. max_callback_queue_size_bytes=%s",
|
||||||
device_repr, max_callback_queue_size_bytes)
|
device_repr, max_callback_queue_size_bytes)
|
||||||
if jaxlib.xla_extension_version > 143:
|
_callback_handler_data.receiver = outfeed_receiver_module.start(
|
||||||
# TODO(phawkins): remove type:ignore after minimum jaxlib version bump
|
_callback_input_received, tuple(clients_with_outfeed),
|
||||||
_callback_handler_data.receiver = outfeed_receiver_module.start(
|
max_callback_queue_size_bytes,
|
||||||
_callback_input_received, tuple(clients_with_outfeed),
|
xb.get_compile_options(1, 1).executable_build_options) # type:ignore
|
||||||
max_callback_queue_size_bytes,
|
|
||||||
xb.get_compile_options(1, 1).executable_build_options) # type:ignore
|
|
||||||
else:
|
|
||||||
_callback_handler_data.receiver = outfeed_receiver_module.start(
|
|
||||||
_callback_input_received, tuple(clients_with_outfeed),
|
|
||||||
max_callback_queue_size_bytes)
|
|
||||||
|
|
||||||
def exit_handler():
|
def exit_handler():
|
||||||
# Prevent logging usage during compilation, gives errors under pytest
|
# Prevent logging usage during compilation, gives errors under pytest
|
||||||
|
@ -129,15 +129,11 @@ def serialize_native(fun_jax: Callable,
|
|||||||
assert len(module_kept_var_idx) == len(args_avals)
|
assert len(module_kept_var_idx) == len(args_avals)
|
||||||
mlir_module = compute_dim_vars(mlir_module, args_avals)
|
mlir_module = compute_dim_vars(mlir_module, args_avals)
|
||||||
|
|
||||||
if xla_client.mlir_api_version >= 46:
|
xla_call_module_version = 4
|
||||||
xla_call_module_version = 4
|
mlir_str = mlir.module_to_bytecode(mlir_module)
|
||||||
mlir_str = mlir.module_to_bytecode(mlir_module)
|
target_version = stablehlo.get_earliest_forward_compatible_version()
|
||||||
target_version = stablehlo.get_earliest_forward_compatible_version()
|
mlir_module_serialized = xla_client._xla.mlir.serialize_portable_artifact(
|
||||||
mlir_module_serialized = xla_client._xla.mlir.serialize_portable_artifact(
|
mlir_str, target_version)
|
||||||
mlir_str, target_version)
|
|
||||||
else:
|
|
||||||
xla_call_module_version = 3
|
|
||||||
mlir_module_serialized = mlir.module_to_bytecode(mlir_module)
|
|
||||||
|
|
||||||
# Figure out the result types and shapes
|
# Figure out the result types and shapes
|
||||||
if "global_out_avals" in lowered.compile_args:
|
if "global_out_avals" in lowered.compile_args:
|
||||||
|
@ -2631,17 +2631,6 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
|
|||||||
if harness.group_name in require_stablehlo_feature_support:
|
if harness.group_name in require_stablehlo_feature_support:
|
||||||
raise unittest.SkipTest(
|
raise unittest.SkipTest(
|
||||||
"native lowering with shape polymorphism requires additional StableHLO feature support")
|
"native lowering with shape polymorphism requires additional StableHLO feature support")
|
||||||
# API version 47 supports CHLO ops that decompose into shape dialect ops
|
|
||||||
if xla_client.mlir_api_version < 47:
|
|
||||||
require_stablehlo_feature_support_shape_dialect = {
|
|
||||||
"vmap_acosh", "vmap_asin", "vmap_asinh", "vmap_atan", "vmap_atanh",
|
|
||||||
"vmap_bessel_i1e", "vmap_cosh", "vmap_digamma", "vmap_erf",
|
|
||||||
"vmap_erfc", "vmap_lgamma", "vmap_nextafter",
|
|
||||||
"vmap_nextafter_broadcasting", "vmap_sinh"
|
|
||||||
}
|
|
||||||
if harness.group_name in require_stablehlo_feature_support_shape_dialect:
|
|
||||||
raise unittest.SkipTest(
|
|
||||||
"native lowering with shape polymorphism requires additional StableHLO feature support")
|
|
||||||
if (jtu.device_under_test() == "tpu" and
|
if (jtu.device_under_test() == "tpu" and
|
||||||
harness.fullname in [
|
harness.fullname in [
|
||||||
"jnp.cumsum_reduce_axis=poly",
|
"jnp.cumsum_reduce_axis=poly",
|
||||||
|
@ -16,7 +16,7 @@
|
|||||||
# eval()-ed by setup.py, so it should not have any dependencies.
|
# eval()-ed by setup.py, so it should not have any dependencies.
|
||||||
|
|
||||||
__version__ = "0.4.8"
|
__version__ = "0.4.8"
|
||||||
_minimum_jaxlib_version = "0.4.6"
|
_minimum_jaxlib_version = "0.4.7"
|
||||||
|
|
||||||
def _version_as_tuple(version_str):
|
def _version_as_tuple(version_str):
|
||||||
return tuple(int(i) for i in version_str.split(".") if i.isdigit())
|
return tuple(int(i) for i in version_str.split(".") if i.isdigit())
|
||||||
|
@ -19,9 +19,7 @@ from absl.testing import parameterized
|
|||||||
import jax
|
import jax
|
||||||
from jax._src import core
|
from jax._src import core
|
||||||
from jax._src import api_util
|
from jax._src import api_util
|
||||||
from jax._src.interpreters import xla
|
|
||||||
from jax._src.interpreters import pxla
|
from jax._src.interpreters import pxla
|
||||||
from jax._src.lib import xla_extension_version
|
|
||||||
from jax import dtypes
|
from jax import dtypes
|
||||||
from jax._src import lib as jaxlib
|
from jax._src import lib as jaxlib
|
||||||
from jax import numpy as jnp
|
from jax import numpy as jnp
|
||||||
@ -33,8 +31,6 @@ config.parse_flags_with_absl()
|
|||||||
|
|
||||||
def _cpp_device_put(value, device):
|
def _cpp_device_put(value, device):
|
||||||
aval = api_util.shaped_abstractify(value)
|
aval = api_util.shaped_abstractify(value)
|
||||||
if xla_extension_version < 139:
|
|
||||||
value = xla.canonicalize_dtype(value)
|
|
||||||
return pxla.batched_device_put(
|
return pxla.batched_device_put(
|
||||||
aval, jax.sharding.SingleDeviceSharding(device), [value], [device])
|
aval, jax.sharding.SingleDeviceSharding(device), [value], [device])
|
||||||
|
|
||||||
|
@ -39,7 +39,6 @@ from jax.interpreters import xla
|
|||||||
from jax._src.interpreters import mlir
|
from jax._src.interpreters import mlir
|
||||||
from jax.interpreters import batching
|
from jax.interpreters import batching
|
||||||
from jax._src import array
|
from jax._src import array
|
||||||
from jax._src.lib import xla_extension_version
|
|
||||||
from jax._src.lib.mlir.dialects import hlo
|
from jax._src.lib.mlir.dialects import hlo
|
||||||
from jax._src import dtypes
|
from jax._src import dtypes
|
||||||
from jax._src.interpreters import pxla
|
from jax._src.interpreters import pxla
|
||||||
@ -2855,8 +2854,6 @@ class FooTyRules:
|
|||||||
buf, = arr._arrays
|
buf, = arr._arrays
|
||||||
else:
|
else:
|
||||||
buf, = arr
|
buf, = arr
|
||||||
if xla_extension_version < 140:
|
|
||||||
buf.aval = core.ShapedArray(buf.shape, buf.dtype)
|
|
||||||
return FooArray(aval.shape, buf)
|
return FooArray(aval.shape, buf)
|
||||||
return handler
|
return handler
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user