mirror of https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00

Bump minimum jaxlib version to 0.4.27

xla_extension_version is 261 and mlir_api_version is 56

PiperOrigin-RevId: 631579739

This commit is contained in:
parent 5ba56bb075
commit 395d3cb79e
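Most of the diff below deletes version guards on jax._src.lib.xla_extension_version (and xc.mlir_api_version) that become dead code once jaxlib 0.4.27 is the minimum. A minimal sketch of the recurring pattern, using hypothetical new_api/old_api stand-ins rather than any specific call from this commit:

# Before the bump: branch on the extension version shipped with jaxlib.
from jax._src.lib import xla_extension_version

if xla_extension_version >= 244:
  result = new_api(module)   # hypothetical newer jaxlib call
else:
  result = old_api(module)   # hypothetical fallback for older jaxlib

# After the bump: the minimum jaxlib guarantees the new code path, so the
# guard, the fallback branch, and eventually the import itself are removed.
result = new_api(module)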
@@ -65,7 +65,6 @@ from jax._src.api_util import (
 from jax._src.lax import lax as lax_internal
 from jax._src.lib import jax_jit
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src.lib import pmap_lib
 from jax._src.sharding import Sharding
 from jax._src.sharding_impls import (PmapSharding, TransferToMemoryKind,
@@ -545,11 +544,7 @@ def xla_computation(fun: Callable,
         result_shardings=None,
         lowering_parameters=mlir.LoweringParameters())
 
-    if xla_extension_version >= 244:
-      m = mlir.module_to_bytecode(lowering_result.module)
-    else:
-      m = mlir.module_to_string(lowering_result.module)
-
+    m = mlir.module_to_bytecode(lowering_result.module)
     built = xc._xla.mlir.mlir_module_to_xla_computation(
         m, use_tuple_args=tuple_args, return_tuple=True)
     out_shapes_flat = [
@@ -1812,8 +1807,7 @@ def _cpp_pmap(
 
   cpp_mapped_f = pmap_lib.pmap(
       fun, cache_miss, static_broadcasted_tuple,
-      pxla.shard_arg if xla_extension_version >= 229 else pxla.temp_shard_arg,  # type: ignore
-      pytree_registry=tree_util.default_registry)
+      pxla.shard_arg, pytree_registry=tree_util.default_registry)  # type: ignore
   _pmap_cache_clears.add(cpp_mapped_f)
 
   pmap_f = wraps(fun)(cpp_mapped_f)
@@ -2912,9 +2906,6 @@ def block_until_ready(x):
     except AttributeError:
       return x
 
-  if xla_extension_version < 246:
-    return tree_map(try_to_block, x)
-
   arrays = []
   for leaf in tree_leaves(x):
    if isinstance(leaf, array.ArrayImpl):
@@ -36,7 +36,6 @@ from jax._src import profiler
 from jax._src import tree_util
 from jax._src import xla_bridge
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src.lib import xla_extension as xe
 from jax._src.interpreters import mlir
 from jax._src.interpreters import pxla
@@ -708,21 +707,13 @@ def make_array_from_callback(
     return pxla.batched_device_put(
         aval, sharding, per_device_values, devices, committed=True)
 
-  # After minimum jaxlib version >= 0.4.26, merge this condition into the
-  # following if block.
-  if xla_extension_version >= 256 and isinstance(first_value, ArrayImpl):
-    maybe_default_layout = pxla._maybe_get_default_layout(
-        Layout(dll, sharding), None, sharding, aval)
-    layout_eq = first_value.layout.device_local_layout == maybe_default_layout
-  else:
-    layout_eq = True
-
   if (isinstance(first_value, ArrayImpl)
       and first_value._committed
       and sharding.is_fully_replicated
       and first_value.is_fully_replicated
       and first_value.sharding._device_assignment == tuple(devices)
-      and layout_eq):
+      and (first_value.layout.device_local_layout ==
+           pxla._maybe_get_default_layout(Layout(dll, sharding), None, sharding, aval))):
     return first_value
 
   if dll is not None:
@@ -33,7 +33,6 @@ from jax._src import profiler
 from jax._src import traceback_util
 from jax._src.interpreters import mlir
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir import ir
 from jax._src.xla_bridge import process_count
 import numpy as np
@@ -253,9 +252,7 @@ def compile_or_get_cached(
   # Persistent compilation cache only implemented on TPU and GPU and the backend
   # that supports serialization of executables.
  # TODO(skye): add warning when initializing cache on unsupported default platform
-  supported_platforms = ["tpu", "gpu"]
-  if xla_extension_version >= 253:
-    supported_platforms.append("cpu")
+  supported_platforms = ["tpu", "gpu", "cpu"]
   use_compilation_cache = (
       config.enable_compilation_cache.value
       and getattr(backend, "supports_executable_serialization", True)
@@ -22,7 +22,6 @@ from jax._src import array
 from jax._src import xla_bridge
 from jax._src.lax.lax import _array_copy
 from jax._src.lib import xla_client
-from jax._src.lib import xla_extension_version
 from jax._src.typing import Array, DLDeviceType
 from jax._src.sharding import Sharding
 
@@ -39,10 +38,7 @@ MIN_DLPACK_VERSION = (0, 5)
 SUPPORTED_DTYPES = frozenset({
     jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.uint8, jnp.uint16,
     jnp.uint32, jnp.uint64, jnp.float16, jnp.bfloat16, jnp.float32,
-    jnp.float64, jnp.complex64, jnp.complex128})
-
-if xla_extension_version >= 231:
-  SUPPORTED_DTYPES = SUPPORTED_DTYPES | frozenset({jnp.bool_})
+    jnp.float64, jnp.complex64, jnp.complex128, jnp.bool_})
 
 
 def _to_dlpack(x: Array, stream: int | Any | None,
@@ -49,7 +49,6 @@ from jax._src.interpreters import xla
 from jax._src.layout import AutoLayout, DeviceLocalLayout
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension
-from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir import dialects
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import func as func_dialect
@@ -98,7 +97,7 @@ def dense_int_array(xs) -> ir.DenseIntElementsAttr | ir.DenseI64ArrayAttr:
 
 # TODO: b/321794305 - delete this when jaxlib is on StableHLO API v6 or higher
 def dense_int_array_v6(xs) -> ir.DenseIntElementsAttr | ir.DenseI64ArrayAttr:
-  if hlo.get_api_version() < 6 or xc.mlir_api_version < 55:
+  if hlo.get_api_version() < 6:
     return dense_int_elements(xs)
   return ir.DenseI64ArrayAttr.get(np.asarray(xs, np.int64))
 
@@ -113,7 +112,7 @@ def dense_bool_elements(xs: Sequence[bool]) -> ir.DenseElementsAttr:
 
 def dense_bool_array(xs: Sequence[bool]) -> ir.DenseElementsAttr | ir.DenseBoolArrayAttr:
   # TODO: b/321794305 - remove this check when jaxlib is on StableHLO API v6 or higher
-  if hlo.get_api_version() < 6 or xc.mlir_api_version < 55:
+  if hlo.get_api_version() < 6:
     return dense_bool_elements(xs)
   return ir.DenseBoolArrayAttr.get(xs)
 
@@ -971,9 +970,7 @@ def lower_jaxpr_to_module(
     attrs["sym_name"] = ir.StringAttr.get(module_name)
     attrs["mhlo.num_replicas"] = i32_attr(num_replicas)
     attrs["mhlo.num_partitions"] = i32_attr(num_partitions)
-    replace_tokens_with_dummy = lowering_parameters.replace_tokens_with_dummy
-    if xla_extension_version >= 260:
-      replace_tokens_with_dummy = False
+    replace_tokens_with_dummy = False
    lower_jaxpr_to_fun(
        ctx, "main", jaxpr, ordered_effects,
        name_stack=name_stack,
@@ -62,7 +62,6 @@ from jax._src.interpreters import mlir
 from jax._src.interpreters import xla
 from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.partition_spec import PartitionSpec
@@ -1800,15 +1799,9 @@ class SemanticallyEqualShardings:
 
  def __init__(self, shardings: tuple[GSPMDSharding | UnspecifiedValue, ...],
               avals: tuple[core.AbstractValue]):
-    if xla_extension_version < 241:
-      gspmd_shardings = [
-          s if is_unspecified_or_auto(s) or a is core.abstract_token
-          else to_gspmd_sharding(s, a.ndim)  # type: ignore
-          for s, a in zip(shardings, avals)]
-    else:
-      gspmd_shardings = [
-          s if is_unspecified_or_auto(s) else to_gspmd_sharding(s, a.ndim)  # type: ignore
-          for s, a in zip(shardings, avals)]
+    gspmd_shardings = [
+        s if is_unspecified_or_auto(s) else to_gspmd_sharding(s, a.ndim)  # type: ignore
+        for s, a in zip(shardings, avals)]
    self._gspmd_shardings = gspmd_shardings
    self.shardings = shardings
    self.avals = avals
@@ -2017,13 +2010,9 @@ def to_gspmd_sharding(s: sharding_impls.XLACompatibleSharding,
                       ndim: int) -> GSPMDSharding:
  if isinstance(s, GSPMDSharding):
    return s
-  if xla_extension_version >= 234:
-    return GSPMDSharding(s._device_assignment, s._to_xla_hlo_sharding(ndim),
-                         memory_kind=s.memory_kind,
-                         _device_list=getattr(s, '_internal_device_list', None))
-  else:
-    return GSPMDSharding(s._device_assignment, s._to_xla_hlo_sharding(ndim),
-                         memory_kind=s.memory_kind)
+  return GSPMDSharding(s._device_assignment, s._to_xla_hlo_sharding(ndim),
+                       memory_kind=s.memory_kind,
+                       _device_list=getattr(s, '_internal_device_list', None))
 
 
 # Dummy function which is a no-op in OSS since enhanced barrier is switched on
@@ -2121,10 +2110,6 @@ def lower_sharding_computation(
      any(not is_unspecified(js) for js, _ in jaxpr_sharding) or
      any(not is_unspecified(o) for o in out_shardings))
 
-  if xla_extension_version < 241:
-    gs = GSPMDSharding.get_replicated(device_assignment)
-    in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)
-
  da_object = _create_da_object(tuple(device_assignment))
 
  all_default_mem_kind = are_all_shardings_default_mem_kind(
@@ -2393,38 +2378,6 @@ class MeshComputation(stages.XlaLowering):
    return xe.hlo_module_cost_analysis(backend, self.hlo().as_hlo_module())
 
 
-if xla_extension_version < 229:
-  def _get_input_indices(
-      avals: Sequence[ShapedArray],
-      shardings: Sequence[sharding_impls.XLACompatibleSharding],
-      da_object: xc.DeviceList | Sequence[xc.Device],  # type: ignore
-  ) -> Sequence[tuple[Index | None, ...]]:
-
-    input_indices = []
-    if not isinstance(da_object, xc.DeviceList):
-      da_object = _create_da_object(tuple(da_object))
-    num_addressable_devices = len(da_object.addressable_device_list)
-
-    def _get_replicated_slices(num_addressable_devices: int, ndim: int | None):
-      if ndim is None:
-        return ((slice(None),),) * num_addressable_devices
-      else:
-        return ((slice(None),) * ndim,) * num_addressable_devices
-
-    for aval, sharding in zip(avals, shardings):
-      if aval is core.abstract_token:
-        index = _get_replicated_slices(num_addressable_devices, None)
-      else:
-        if sharding.is_fully_replicated:
-          index = _get_replicated_slices(num_addressable_devices, aval.ndim)
-        else:
-          index = tuple(
-              sharding.addressable_devices_indices_map(aval.shape).values())  # type: ignore
-      input_indices.append(index)
-
-    return input_indices
-
-
 def get_out_shardings_from_executable(
    xla_executable,
    device_assignment: Sequence[xc.Device],
@@ -2716,8 +2669,7 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
                         get_logical_mesh_ids(list(mesh.shape.values()))
                         .reshape(-1))
  compile_options.parameter_is_tupled_arguments = tuple_args
-  if xla_extension_version >= 241:
-    opts.allow_spmd_sharding_propagation_to_parameters = list(allow_prop_to_inputs)
+  opts.allow_spmd_sharding_propagation_to_parameters = list(allow_prop_to_inputs)
  opts.allow_spmd_sharding_propagation_to_output = list(allow_prop_to_outputs)
 
  with dispatch.log_elapsed_time(
@@ -2835,13 +2787,7 @@ class UnloadedMeshExecutable:
  all_args_info: AllArgsInfo | None
 
  def build_unsafe_call(self):
-    if xla_extension_version >= 229:
-      handle_args = InputsHandler(self.input_shardings)
-    else:
-      input_indices = _get_input_indices(self.input_avals, self.input_shardings,
-                                         self.device_assignment)
-      handle_args = InputsHandler(
-          self.input_shardings, self.xla_executable.local_devices(), input_indices)
+    handle_args = InputsHandler(self.input_shardings)
    handle_outs = global_avals_to_results_handler(
        self.output_avals, self.output_shardings, self.committed)  # type: ignore  # arg-type
 
@@ -2928,10 +2874,9 @@ class UnloadedMeshExecutable:
    else:
      if pmap_nreps == 1:
        assert mesh is None
-        if xla_extension_version >= 241:
-          in_shardings = _maybe_get_and_check_in_shardings(
-              xla_executable, in_shardings, tuple(da), global_in_avals,
-              len(ordered_effects))
+        in_shardings = _maybe_get_and_check_in_shardings(
+            xla_executable, in_shardings, tuple(da), global_in_avals,
+            len(ordered_effects))
        out_shardings = _maybe_get_and_check_out_shardings(
            xla_executable, out_shardings, tuple(da), global_out_avals,
            len(ordered_effects), all_default_mem_kind)
@@ -3096,19 +3041,9 @@ class MeshExecutable(stages.XlaExecutable):
        fastpath_data = None
      return outs, fastpath_data
 
-    if xla_extension_version >= 226:
-      return xc._xla.pjit(
-          self.unsafe_call.name, None, aot_cache_miss, [], [], [],
-          tree_util.dispatch_registry,
-          shard_arg if xla_extension_version >= 229 else temp_shard_arg)  # type: ignore
-    else:
-      return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [],  # type: ignore
-                          tree_util.dispatch_registry)
-
-
-# TODO(yashkatariya): Remove once minimum jaxlib version is 0.4.24
-def temp_shard_arg(arg, devices, arg_indices, sharding, canonicalize=True):
-  return shard_arg(arg, sharding)
+    return xc._xla.pjit(
+        self.unsafe_call.name, None, aot_cache_miss, [], [], [],
+        tree_util.dispatch_registry, shard_arg)  # type: ignore
 
 
 def check_arg_avals_for_call(ref_avals, arg_avals,
@@ -3216,7 +3151,7 @@ def check_array_xla_sharding_layout_match(
              f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}",
              'sharding'))
 
-    if (xla_extension_version >= 249 and not db_xs and arg._committed and
+    if (not db_xs and arg._committed and
        arg.layout.device_local_layout is not None and xl is not None and
        arg.layout.device_local_layout != xl):
      errors.append(
@@ -62,7 +62,7 @@ from jax._src.core import ShapedArray, ConcreteArray
 from jax._src.lax.lax import (_array_copy, _sort_lt_comparator,
                               _sort_le_comparator, PrecisionLike)
 from jax._src.lax import lax as lax_internal
-from jax._src.lib import xla_client as xc, xla_extension_version
+from jax._src.lib import xla_client as xc
 from jax._src.numpy import reductions
 from jax._src.numpy import ufuncs
 from jax._src.numpy import util
@@ -2526,19 +2526,14 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
  if hasattr(object, '__jax_array__'):
    object = object.__jax_array__()
  elif hasattr(object, '__cuda_array_interface__'):
-    if xla_extension_version >= 237:
-      cai = object.__cuda_array_interface__
-      backend = xla_bridge.get_backend("cuda")
-      if cuda_plugin_extension is None:
-        device_id = None
-      else:
-        device_id = cuda_plugin_extension.get_device_ordinal(cai["data"][0])
-      if xla_extension_version >= 261:
-        object = xc._xla.cuda_array_interface_to_buffer(
-            cai=cai, gpu_backend=backend, device_id=device_id
-        )
-      else:
-        object = xc._xla.cuda_array_interface_to_buffer(cai, backend)
+    cai = object.__cuda_array_interface__
+    backend = xla_bridge.get_backend("cuda")
+    if cuda_plugin_extension is None:
+      device_id = None
+    else:
+      device_id = cuda_plugin_extension.get_device_ordinal(cai["data"][0])
+    object = xc._xla.cuda_array_interface_to_buffer(
+        cai=cai, gpu_backend=backend, device_id=device_id)
 
  object = tree_map(lambda leaf: leaf.__jax_array__()
                    if hasattr(leaf, "__jax_array__") else leaf, object)
@@ -61,7 +61,6 @@ from jax._src.interpreters import pxla
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import func as func_dialect
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src.sharding_impls import (
     NamedSharding, XLACompatibleSharding, GSPMDSharding,
     SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue,
@@ -311,19 +310,12 @@ def _cpp_pjit(jit_info: PjitInfo):
    return outs, maybe_fastpath_data
 
  fun = jit_info.fun
-  if xla_extension_version >= 226:
-    cpp_pjit_f = xc._xla.pjit(  # type: ignore
-        getattr(fun, "__name__", "<unnamed function>"),
-        fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames,
-        jit_info.donate_argnums, tree_util.dispatch_registry,
-        pxla.shard_arg if xla_extension_version >= 229 else pxla.temp_shard_arg,  # type: ignore
-        _get_cpp_global_cache(jit_info.has_explicit_sharding))  # type: ignore
-  else:
-    cpp_pjit_f = xc._xla.pjit(  # type: ignore
-        getattr(fun, "__name__", "<unnamed function>"),
-        fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames,
-        jit_info.donate_argnums, tree_util.dispatch_registry,
-        _get_cpp_global_cache(jit_info.has_explicit_sharding))
+  cpp_pjit_f = xc._xla.pjit(  # type: ignore
+      getattr(fun, "__name__", "<unnamed function>"),
+      fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames,
+      jit_info.donate_argnums, tree_util.dispatch_registry,
+      pxla.shard_arg,  # type: ignore
+      _get_cpp_global_cache(jit_info.has_explicit_sharding))  # type: ignore
 
  cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
  cpp_pjitted_f._fun = fun
@@ -1548,16 +1540,11 @@ def _pjit_call_impl(*args, jaxpr,
  donated_argnums = [i for i, d in enumerate(donated_invars) if d]
  has_explicit_sharding = _pjit_explicit_sharding(
      in_shardings, out_shardings, None, None)
-  if xla_extension_version >= 226:
-    return xc._xla.pjit(
-        name, f, call_impl_cache_miss, [], [], donated_argnums,
-        tree_util.dispatch_registry,
-        pxla.shard_arg if xla_extension_version >= 229 else pxla.temp_shard_arg,  # type: ignore
-        _get_cpp_global_cache(has_explicit_sharding))(*args)
-  else:
-    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,  # type: ignore
-                        tree_util.dispatch_registry,
-                        _get_cpp_global_cache(has_explicit_sharding))(*args)
+  return xc._xla.pjit(
+      name, f, call_impl_cache_miss, [], [], donated_argnums,
+      tree_util.dispatch_registry,
+      pxla.shard_arg,  # type: ignore
+      _get_cpp_global_cache(has_explicit_sharding))(*args)
 
 pjit_p.def_impl(_pjit_call_impl)
 
@@ -1807,12 +1794,9 @@ def _pjit_batcher_for_sharding(
      tad = list(new_op.tile_assignment_dimensions)
      tad.insert(dim, 1)
      new_op.tile_assignment_dimensions = tad
-      if xla_extension_version >= 234:
-        new_gs = GSPMDSharding(
-            s._device_assignment, new_op,  # type: ignore
-            _device_list=getattr(s, '_internal_device_list', None))
-      else:
-        new_gs = GSPMDSharding(s._device_assignment, new_op)  # type: ignore
+      new_gs = GSPMDSharding(
+          s._device_assignment, new_op,  # type: ignore
+          _device_list=getattr(s, '_internal_device_list', None))
      return pxla._get_out_sharding_from_orig_sharding([new_gs], [None], s, None)[0]  # type: ignore
  else:
    if isinstance(s, NamedSharding):
@@ -23,7 +23,6 @@ from jax._src.tree_util import tree_flatten, tree_unflatten
 from jax._src.interpreters import batching
 from jax._src.util import safe_zip
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src.api_util import shaped_abstractify
 from jax._src.lib.mlir import ir
 
@@ -31,8 +30,6 @@ _next_shard_group_id = itertools.count()
 
 def shard_alike(x, y):
   """Shards x and y alike."""
-  if xla_extension_version < 227:
-    raise ValueError("shard_alike requires jaxlib v0.4.24 or newer.")
   x_flat, x_tree = tree_flatten(x)
   y_flat, y_tree = tree_flatten(y)
 
@@ -35,7 +35,6 @@ from jax._src import util
 from jax._src import xla_bridge
 from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src.partition_spec import PartitionSpec
 
 import numpy as np
@@ -190,10 +189,6 @@ def named_sharding_to_xla_hlo_sharding(
  axis_names = self.mesh.axis_names
  for manual_axis in self._manual_axes:
    special_axes[axis_names.index(manual_axis)] = xc.OpSharding.Type.MANUAL
-    if xla_extension_version < 259:
-      if manual_axis in array_mapping:  # type: ignore
-        raise ValueError(f"manual axis {repr(manual_axis)} in {repr(self)} "
-                         "cannot be used as a sharded axis")
 
  replicated_mesh_axes = []
  for i, (axis_name, axis_val) in enumerate(mesh_shape.items()):
@@ -297,13 +292,6 @@ class NamedSharding(XLACompatibleSharding):
    self._manual_axes = _manual_axes
    self._parsed_pspec = preprocess(self.mesh, self.spec, _parsed_pspec)
 
-  # TODO(phawkins): remove this method when jaxlib 0.4.26 or newer is the
-  # minimum. This method is called by the C++ sharding implementation in earlier
-  # versions.
-  if xla_extension_version < 243:
-    def _preprocess(self):
-      self._parsed_pspec = preprocess(self.mesh, self.spec, self._parsed_pspec)
-
  def __repr__(self):
    mesh_repr = ", ".join(f"'{k}': {v}" for k, v in self.mesh.shape.items())
    mem = '' if self.memory_kind is None else f', memory_kind={self.memory_kind}'
@@ -683,8 +671,7 @@ class PositionalSharding(XLACompatibleSharding):
 
  def __init__(self, devices: Sequence[xc.Device] | np.ndarray,
               *, memory_kind: str | None = None):
-    if xla_extension_version >= 235:
-      super().__init__()
+    super().__init__()
    if not isinstance(devices, np.ndarray):
      devices = np.array(devices, dtype='object')
    if not devices.size:
@@ -48,7 +48,6 @@ from jax._src.layout import Layout
 from jax._src.interpreters import mlir
 from jax._src.lib.mlir import ir
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 
 
 source_info_util.register_exclusion(__file__)
@@ -315,10 +314,7 @@ class XlaLowering(Lowering):
    """Return an HLO representation of this computation."""
    hlo = self.stablehlo()
    m: Union[str, bytes]
-    if xla_extension_version >= 244:
-      m = mlir.module_to_bytecode(hlo)
-    else:
-      m = mlir.module_to_string(hlo)
+    m = mlir.module_to_bytecode(hlo)
    return xla_extension.mlir.mlir_module_to_xla_computation(
        m, use_tuple_args=self.compile_args["tuple_args"])
 
@@ -25,7 +25,6 @@ from typing import Any, Callable, NamedTuple, TypeVar, Union, overload
 
 from jax._src import traceback_util
 from jax._src.lib import pytree
-from jax._src.lib import xla_extension_version
 from jax._src.util import safe_zip
 from jax._src.util import unzip2
 
@@ -583,50 +582,26 @@ def broadcast_prefix(prefix_tree: Any, full_tree: Any,
  tree_map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf)
  return result
 
-if xla_extension_version >= 248:
-  def flatten_one_level(pytree: Any) -> tuple[Iterable[Any], Hashable]:
-    """Flatten the given pytree node by one level.
-
-    Args:
-      pytree: A valid pytree node, either built-in or registered via
-        ``register_pytree_node`` or ``register_pytree_with_keys``.
-
-    Returns:
-      A pair of the pytree's flattened children and its hashable metadata.
-
-    Raises:
-      ValueError: If the given pytree is not a built-in or registered container
-        via ``register_pytree_node`` or ``register_pytree_with_keys``.
-    """
-    out = default_registry.flatten_one_level(pytree)
-    if out is None:
-      raise ValueError(f"can't tree-flatten type: {type(pytree)}")
-    else:
-      return out
-else:
-  def flatten_one_level(pytree: Any) -> tuple[Iterable[Any], Hashable]:
-    """Flatten the given pytree node by one level.
-
-    Args:
-      pytree: A valid pytree node, either built-in or registered via
-        ``register_pytree_node`` or ``register_pytree_with_keys``.
-
-    Returns:
-      A pair of the pytree's flattened children and its hashable metadata.
-
-    Raises:
-      ValueError: If the given pytree is not a built-in or registered container
-        via ``register_pytree_node`` or ``register_pytree_with_keys``.
-    """
-    handler = _registry.get(type(pytree))
-    if handler:
-      children, meta = handler.to_iter(pytree)
-      return list(children), meta
-    elif isinstance(pytree, tuple) and hasattr(pytree, '_fields'):
-      # handle namedtuple as a special case, based on heuristic
-      return [getattr(pytree, s) for s in pytree._fields], None
-    else:
-      raise ValueError(f"can't tree-flatten type: {type(pytree)}")
+def flatten_one_level(pytree: Any) -> tuple[Iterable[Any], Hashable]:
+  """Flatten the given pytree node by one level.
+
+  Args:
+    pytree: A valid pytree node, either built-in or registered via
+      ``register_pytree_node`` or ``register_pytree_with_keys``.
+
+  Returns:
+    A pair of the pytree's flattened children and its hashable metadata.
+
+  Raises:
+    ValueError: If the given pytree is not a built-in or registered container
+      via ``register_pytree_node`` or ``register_pytree_with_keys``.
+  """
+  out = default_registry.flatten_one_level(pytree)
+  if out is None:
+    raise ValueError(f"can't tree-flatten type: {type(pytree)}")
+  else:
+    return out
 
 def prefix_errors(prefix_tree: Any, full_tree: Any,
                  is_leaf: Callable[[Any], bool] | None = None,
@@ -876,11 +851,6 @@ def register_dataclass(
    meta_fields: auxiliary data field names.
    data_fields: data field names.
  """
-  if xla_extension_version < 259:
-    raise NotImplementedError(
-        "Registering dataclasses is only supported in jaxlib>=0.4.26."
-    )
-
  def flatten_with_keys(x):
    meta = tuple(getattr(x, name) for name in meta_fields)
    data = tuple((GetAttrKey(name), getattr(x, name)) for name in data_fields)
@@ -969,64 +939,37 @@ _generate_key_paths = generate_key_paths  # alias for backward compat
 
 
 # The overall logic should be same as PyTreeDef::FlattenIntoImpl
-if xla_extension_version >= 248:
-  def _generate_key_paths_(
-      key_path: KeyPath,
-      tree: Any,
-      is_leaf: Callable[[Any], bool] | None = None,
-  ) -> Iterable[tuple[KeyPath, Any]]:
-    if is_leaf and is_leaf(tree):
-      yield key_path, tree
-      return
-    key_handler = _registry_with_keypaths.get(type(tree))
-    if key_handler:
-      key_children, _ = key_handler.flatten_with_keys(tree)
-      for k, c in key_children:
-        yield from _generate_key_paths_((*key_path, k), c, is_leaf)
-      return
-
-    flat = default_registry.flatten_one_level(tree)
-    if flat is None:
-      yield key_path, tree  # strict leaf type
-      return
-
-    if (isinstance(tree, tuple) and hasattr(tree, '_fields') and
-        flat[1] == type(tree)):
-      # handle namedtuple as a special case, based on heuristic
-      key_children = [(GetAttrKey(s), getattr(tree, s)) for s in tree._fields]
-      for k, c in key_children:
-        yield from _generate_key_paths_((*key_path, k), c, is_leaf)
-      return
-
-    for i, c in enumerate(flat[0]):
-      k = FlattenedIndexKey(i)
-      yield from _generate_key_paths_((*key_path, k), c, is_leaf)
-else:
-  def _generate_key_paths_(
-      key_path: KeyPath,
-      tree: Any,
-      is_leaf: Callable[[Any], bool] | None = None,
-  ) -> Iterable[tuple[KeyPath, Any]]:
-    if is_leaf and is_leaf(tree):
-      yield key_path, tree
-      return
-    key_handler = _registry_with_keypaths.get(type(tree))
-    if key_handler:
-      key_children, _ = key_handler.flatten_with_keys(tree)
-      for k, c in key_children:
-        yield from _generate_key_paths_((*key_path, k), c, is_leaf)
-    elif handler := _registry.get(type(tree)):
-      children, _ = handler.to_iter(tree)
-      for i, c in enumerate(children):
-        k = FlattenedIndexKey(i)
-        yield from _generate_key_paths_((*key_path, k), c, is_leaf)
-    elif isinstance(tree, tuple) and hasattr(tree, '_fields'):
-      # handle namedtuple as a special case, based on heuristic
-      key_children = [(GetAttrKey(s), getattr(tree, s)) for s in tree._fields]
-      for k, c in key_children:
-        yield from _generate_key_paths_((*key_path, k), c, is_leaf)
-    else:
-      yield key_path, tree  # strict leaf type
+def _generate_key_paths_(
+    key_path: KeyPath,
+    tree: Any,
+    is_leaf: Callable[[Any], bool] | None = None,
+) -> Iterable[tuple[KeyPath, Any]]:
+  if is_leaf and is_leaf(tree):
+    yield key_path, tree
+    return
+  key_handler = _registry_with_keypaths.get(type(tree))
+  if key_handler:
+    key_children, _ = key_handler.flatten_with_keys(tree)
+    for k, c in key_children:
+      yield from _generate_key_paths_((*key_path, k), c, is_leaf)
+    return
+
+  flat = default_registry.flatten_one_level(tree)
+  if flat is None:
+    yield key_path, tree  # strict leaf type
+    return
+
+  if (isinstance(tree, tuple) and hasattr(tree, '_fields') and
+      flat[1] == type(tree)):
+    # handle namedtuple as a special case, based on heuristic
+    key_children = [(GetAttrKey(s), getattr(tree, s)) for s in tree._fields]
+    for k, c in key_children:
+      yield from _generate_key_paths_((*key_path, k), c, is_leaf)
+    return
+
+  for i, c in enumerate(flat[0]):
+    k = FlattenedIndexKey(i)
+    yield from _generate_key_paths_((*key_path, k), c, is_leaf)
 
 
 def tree_map_with_path(f: Callable[..., Any],
@@ -47,7 +47,6 @@ from jax._src.cloud_tpu_init import maybe_import_libtpu
 from jax._src.lib import cuda_versions
 from jax._src.lib import xla_client
 from jax._src.lib import xla_extension
-from jax._src.lib import xla_extension_version
 from jax._src.lib import jaxlib
 
 logger = logging.getLogger(__name__)
@@ -250,25 +249,17 @@ def make_cpu_client() -> xla_client.Client:
    collectives = xla_client._xla.make_gloo_tcp_collectives(  # type: ignore
        distributed_client=distributed.global_state.client,
    )
-  elif collectives_impl == 'mpi' and xla_extension_version >= 251:
+  elif collectives_impl == 'mpi':
    collectives = xla_client._xla.make_mpi_collectives()  # type: ignore
    collectives.Init()  # type: ignore
    atexit.register(collectives.Finalize)  # type: ignore
  elif collectives_impl != 'none':
-    collectives_impls = ['none', 'gloo'
-                         ] + (['mpi'] if xla_extension_version >= 251 else [])
+    collectives_impls = ['none', 'gloo', 'mpi']
    raise RuntimeError(f"Unknown collectives implementation "
                       f"{collectives_impl}. Available implementations are "
                       f"{collectives_impls}.")
-  if xla_extension_version >= 257:
-    return xla_client.make_cpu_client(  # type: ignore
-        asynchronous=_CPU_ENABLE_ASYNC_DISPATCH.value,
-        distributed_client=distributed.global_state.client,
-        node_id=distributed.global_state.process_id,
-        num_nodes=distributed.global_state.num_processes,
-        collectives=collectives,
-    )
+  return xla_client.make_cpu_client(  # type: ignore
+      asynchronous=_CPU_ENABLE_ASYNC_DISPATCH.value,
+      distributed_client=distributed.global_state.client,
+      node_id=distributed.global_state.process_id,
+      num_nodes=distributed.global_state.num_processes,
@@ -703,13 +694,10 @@ def register_plugin(
    c_api = xla_client.load_pjrt_plugin_dynamically(plugin_name, library_path)  # type: ignore
    xla_client.profiler.register_plugin_profiler(c_api)
  else:
-    if xla_extension_version >= 236:
-      assert c_api is not None
-      xla_client.load_pjrt_plugin_with_c_api(plugin_name, c_api)
-    if xla_extension_version >= 239:
-      make_topology = partial(xla_client.make_c_api_device_topology, c_api)
-    else:
-      make_topology = None
+    assert c_api is not None
+    xla_client.load_pjrt_plugin_with_c_api(plugin_name, c_api)
+
+  make_topology = partial(xla_client.make_c_api_device_topology, c_api)
  experimental = plugin_name not in _nonexperimental_plugins
  register_backend_factory(plugin_name, factory, priority=priority,
                           fail_quietly=False, experimental=experimental,
@@ -35,7 +35,6 @@ from jax._src.sharding_impls import _op_sharding_to_pos_sharding
 from jax._src import custom_api_util
 from jax._src import xla_bridge as xb
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src.api_util import flatten_fun_nokwargs, argnums_partial
 
 
@@ -191,10 +190,6 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
      axis_context=axis_context.extend_manual(frozenset(mesh.axis_names)),
  )
  result_sharding = _pack_result_sharding(result_shape, result_shardings)
-  if xla_extension_version < 232:
-    built = xc._xla.mlir.mlir_module_to_xla_computation(
-        mlir.module_to_string(module), use_tuple_args=False, return_tuple=False)
-    return built, arg_shardings, result_sharding
  return mlir.module_to_bytecode(module), arg_shardings, result_sharding
 
 
@@ -548,14 +543,13 @@ xc.register_custom_call_partitioner(  # pytype: disable=module-attr
    _custom_partitioning_propagate_user_sharding,
    _custom_partitioning_partition,
    _custom_partitioning_infer_sharding_from_operands, True)
-if xla_extension_version >= 252:
-  xb.register_plugin_callbacks(
-      partial(
-          xc.register_custom_call_partitioner,
-          name=_CUSTOM_PARTITIONING_CALL_NAME,
-          prop_user_sharding=_custom_partitioning_propagate_user_sharding,
-          partition=_custom_partitioning_partition,
-          infer_sharding_from_operands=_custom_partitioning_infer_sharding_from_operands,
-          can_side_effecting_have_replicated_sharding=True,
-      )
-  )
+xb.register_plugin_callbacks(
+    partial(
+        xc.register_custom_call_partitioner,
+        name=_CUSTOM_PARTITIONING_CALL_NAME,
+        prop_user_sharding=_custom_partitioning_propagate_user_sharding,
+        partition=_custom_partitioning_partition,
+        infer_sharding_from_operands=_custom_partitioning_infer_sharding_from_operands,
+        can_side_effecting_have_replicated_sharding=True,
+    )
+)
@@ -133,7 +133,7 @@ def _get_cmdclass(pkg_source_path):
 
 
 __version__ = _get_version_string()
-_minimum_jaxlib_version = "0.4.23"
+_minimum_jaxlib_version = "0.4.27"
 
 def _version_as_tuple(version_str):
   return tuple(int(i) for i in version_str.split(".") if i.isdigit())
@@ -22,7 +22,6 @@ import jax.numpy as jnp
 from jax.sharding import PartitionSpec as P
 from jax._src import config
 from jax._src import test_util as jtu
-from jax._src.lib import xla_extension_version
 
 import numpy as np
 
@@ -237,7 +236,6 @@ class DLPackTest(jtu.JaxTestCase):
 class CudaArrayInterfaceTest(jtu.JaxTestCase):
 
  @jtu.skip_on_devices("cuda")
-  @unittest.skipIf(xla_extension_version < 228, "Requires newer jaxlib")
  def testCudaArrayInterfaceOnNonCudaFails(self):
    x = jnp.arange(5)
    self.assertFalse(hasattr(x, "__cuda_array_interface__"))
@@ -248,7 +246,6 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase):
      _ = x.__cuda_array_interface__
 
  @jtu.run_on_devices("cuda")
-  @unittest.skipIf(xla_extension_version < 233, "Requires newer jaxlib")
  def testCudaArrayInterfaceOnShardedArrayFails(self):
    devices = jax.local_devices()
    if len(devices) <= 1:
@@ -280,7 +277,6 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase):
    self.assertEqual(z.__array_interface__["typestr"], a["typestr"])
 
  @jtu.run_on_devices("cuda")
-  @unittest.skipIf(xla_extension_version < 228, "Requires newer jaxlib")
  def testCudaArrayInterfaceBfloat16Fails(self):
    rng = jtu.rand_default(self.rng())
    x = rng((2, 2), jnp.bfloat16)
@@ -303,7 +299,6 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase):
                     z.__cuda_array_interface__["data"][0])
    self.assertAllClose(x, cupy.asnumpy(z))
 
-  @unittest.skipIf(xla_extension_version < 237, "Requires newer jaxlib")
  @jtu.sample_product(
      shape=all_shapes,
      dtype=jtu.dtypes.supported(cuda_array_interface_dtypes),
@@ -320,7 +315,6 @@ class CudaArrayInterfaceTest(jtu.JaxTestCase):
                     z.__cuda_array_interface__["data"][0])
    self.assertAllClose(np.asarray(z), cupy.asnumpy(y))
 
-  @unittest.skipIf(xla_extension_version < 237, "Requires newer jaxlib")
  @jtu.sample_product(
      shape=all_shapes,
      dtype=jtu.dtypes.supported(cuda_array_interface_dtypes),
@@ -37,7 +37,6 @@ from jax._src import monitoring
 from jax._src import test_util as jtu
 from jax._src import xla_bridge
 from jax._src.lib import xla_client
-from jax._src.lib import xla_extension_version
 from jax._src.maps import xmap
 from jax.experimental.pjit import pjit
 from jax.sharding import PartitionSpec as P
@@ -72,9 +71,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
 
  def setUp(self):
    super().setUp()
-    supported_platforms = ["tpu", "gpu"]
-    if xla_extension_version >= 253:
-      supported_platforms.append("cpu")
+    supported_platforms = ["tpu", "gpu", "cpu"]
 
    if not jtu.test_device_matches(supported_platforms):
      raise SkipTest(
@@ -40,7 +40,6 @@ from jax._src import effects
 from jax._src import test_util as jtu
 from jax._src import xla_bridge as xb
 from jax._src.interpreters import mlir
-from jax._src.lib import xla_extension_version
 
 from jax._src.lib.mlir.dialects import hlo
 
@@ -1261,8 +1260,6 @@ class JaxExportTest(jtu.JaxTestCase):
 
    mlir_outer_module_str = str(lowered_outer.compiler_ir())
    if exp.mlir_module_serialization_version >= _export._VERSION_START_SUPPORT_EFFECTS_WITH_REAL_TOKENS:
-      if xla_extension_version < 260:
-        main_expected_re = main_expected_re.replace("!stablehlo.token", "tensor<0xi1>")
      self.assertRegex(mlir_outer_module_str, main_expected_re)
 
    res = jax.jit(f_outer)(x)
@@ -20,7 +20,6 @@ from absl.testing import absltest
 import jax
 from jax._src import config
 from jax._src import test_util as jtu
-from jax._src.lib import xla_extension_version
 
 config.parse_flags_with_absl()
 
@@ -36,7 +35,6 @@ class GpuMemoryAllocationTest(absltest.TestCase):
      "XLA_PYTHON_CLIENT_ALLOCATOR" in os.environ,
      "Test does not work if the python client allocator has been overriden",
  )
-  @unittest.skipIf(xla_extension_version < 225, "jaxlib version too old")
  def test_gpu_memory_allocation(self):
    falsey_values = ("0", "False", "false")
    preallocate = (
@@ -31,7 +31,6 @@ from jax._src import util
 from jax._src.interpreters import ad
 from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
-from jax._src.lib import xla_extension_version
 from jax._src.maps import xmap
 import numpy as np
 
@@ -374,56 +373,6 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
                                'incorrect set of output token.'):
      f.lower(2.)
 
-  def test_lowering_ordered_effect_should_create_tokens(self):
-    if xla_extension_version >= 260:
-      self.skipTest('Not applicable anymore')
-    def effect_lowering(ctx, *, effect):
-      ctx.set_tokens_out(ctx.tokens_in)
-      return []
-    mlir.register_lowering(effect_p, effect_lowering)
-
-    @jax.jit
-    def f(x):
-      effect_p.bind(effect=foo_effect)
-      return x + 1.
-    module = f.lower(2.).compiler_ir()
-    main = module.body.operations[0]
-    first_op = main.body.blocks[0].operations[0]
-    self.assertIn('hlo.create_token', first_op.operation.name)
-
-    @jax.jit
-    def f(x):
-      effect_p.bind(effect=foo_effect)
-      effect_p.bind(effect=foo2_effect)
-      return x + 1.
-    module = f.lower(2.).compiler_ir()
-    main = module.body.operations[0]
-    first_op = main.body.blocks[0].operations[0]
-    self.assertIn('hlo.create_token', first_op.operation.name)
-    second_op = main.body.blocks[0].operations[1]
-    self.assertIn('hlo.create_token', second_op.operation.name)
-
-    @jax.jit
-    def f(x):
-      effect_p.bind(effect=foo_effect)
-      return x + 1.
-    module = f.lower(2.).compiler_ir()
-    main = module.body.operations[0]
-    first_op = main.body.blocks[0].operations[0]
-    self.assertIn('hlo.create_token', first_op.operation.name)
-
-    @jax.jit
-    def f(x):
-      effect_p.bind(effect=foo_effect)
-      effect_p.bind(effect=foo2_effect)
-      return x + 1.
-    module = f.lower(2.).compiler_ir()
-    main = module.body.operations[0]
-    first_op = main.body.blocks[0].operations[0]
-    self.assertIn('hlo.create_token', first_op.operation.name)
-    second_op = main.body.blocks[0].operations[1]
-    self.assertIn('hlo.create_token', second_op.operation.name)
-
  def test_nontrivial_lowering_with_ordered_effect_should_consume_token(self):
 
    mlir.register_lowering(effect_p, function_effect_lowering)
@@ -434,13 +383,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
      return x + 1.
    module = f.lower(2.).compiler_ir()
    main = module.body.operations[0]
 
-    if xla_extension_version < 260:
-      first_op = main.body.blocks[0].operations[0]
-      self.assertIn('hlo.create_token', first_op.operation.name)
-      call_op = main.body.blocks[0].operations[1]
-    else:
-      call_op = main.body.blocks[0].operations[0]
+    call_op = main.body.blocks[0].operations[0]
 
    self.assertEqual(call_op.operation.name, 'func.call')
    self.assertEqual(str(call_op.attributes['callee']), '@effect')
@@ -491,9 +434,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
      return x + 1.
    module = f.lower(1.).compiler_ir()
    input_types = module.body.operations[0].type.inputs
-    token_type = (
-        '!stablehlo.token' if xla_extension_version >= 260 else 'tensor<0xi1>'
-    )
+    token_type = '!stablehlo.token'
    # First argument should be a token
    self.assertLen(list(input_types), 2)
    self.assertEqual(str(input_types[0]), token_type)
@@ -511,9 +452,7 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
      return x + 1.
    module = f.lower(1.).compiler_ir()
    input_types = module.body.operations[0].type.inputs
-    token_type = (
-        '!stablehlo.token' if xla_extension_version >= 260 else 'tensor<0xi1>'
-    )
+    token_type = '!stablehlo.token'
    # First two arguments should be token values
    self.assertLen(list(input_types), 3)
    self.assertEqual(str(input_types[0]), token_type)
@@ -45,7 +45,6 @@ from jax._src.interpreters import mlir
 from jax._src.interpreters import pxla
 from jax._src.internal_test_util import lax_test_util
 from jax._src.lax import lax as lax_internal
-from jax._src.lib import xla_extension_version
 from jax._src.util import NumpyComplexWarning
 
 config.parse_flags_with_absl()
@@ -3616,7 +3615,7 @@ class FunctionAccuracyTest(jtu.JaxTestCase):
    elif name == 'log10':
      regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'ninf.imag', 'pinf.imag', 'ninfj.imag', 'pinfj.imag', 'zero.imag')
 
-    elif name == 'log1p' and xla_extension_version < 254:
+    elif name == 'log1p':
      regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'neg.real', 'pos.real',
                                     'negj.real', 'posj.real', 'ninf.real', 'ninfj.real', 'pinfj.real')
 
@@ -3629,7 +3628,7 @@ class FunctionAccuracyTest(jtu.JaxTestCase):
      if dtype == np.complex128:
        regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'pos.imag', 'negj', 'posj', 'ninf', 'pinf', 'mpos.imag')
 
-    elif name == 'expm1' and xla_extension_version < 250:
+    elif name == 'expm1':
      regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'pinf', 'mq1', 'mq2', 'mq3', 'mq4', 'mneg', 'mpos')
 
    elif name == 'sinc':
@@ -3694,10 +3693,7 @@ class FunctionAccuracyTest(jtu.JaxTestCase):
      regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj')
 
    elif name == 'arctanh':
-      if xla_extension_version < 254:
-        regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos.imag')
-      else:
-        regions_with_inaccuracies_keep('pos.imag', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos.imag')
+      regions_with_inaccuracies_keep('pos.imag', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos.imag')
 
    elif name in {'cos', 'sin'}:
      regions_with_inaccuracies_keep('ninf.imag', 'pinf.imag')
@@ -3741,11 +3737,11 @@ class FunctionAccuracyTest(jtu.JaxTestCase):
      with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"):
        self.assertAllClose(
            normalized_result_slice, normalized_expected_slice, atol=atol,
-            err_msg=f"{name} in {region_name}, {is_cpu=} {is_cuda=}, {xla_extension_version=}\n{inexact_samples}")
+            err_msg=f"{name} in {region_name}, {is_cpu=} {is_cuda=},\n{inexact_samples}")
 
    if kind == 'failure' and region_name in regions_with_inaccuracies:
      try:
-        with self.assertRaises(AssertionError, msg=f"{name} in {region_name}, {is_cpu=} {is_cuda=}, {xla_extension_version=}"):
+        with self.assertRaises(AssertionError, msg=f"{name} in {region_name}, {is_cpu=} {is_cuda=}"):
          with jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered in.*"):
            self.assertAllClose(normalized_result_slice, normalized_expected_slice)
      except AssertionError as msg:
@@ -23,7 +23,6 @@ import jax
 from jax import lax
 from jax._src import test_util as jtu
 from jax._src import xla_bridge as xb
-from jax._src.lib import xla_extension_version
 from jax._src import config
 from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
 import jax.numpy as jnp
@@ -1250,7 +1249,7 @@ class ActivationOffloadingTest(jtu.JaxTestCase):
 
    compiled_stats = compiled_f.memory_analysis()
    if compiled_stats is not None and jtu.test_device_matches(["tpu"]):
-      if xla_extension_version >= 240 and jtu.pjrt_c_api_version_at_least(0, 43):
+      if jtu.pjrt_c_api_version_at_least(0, 43):
        self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)
 
  def test_remat_scan_jaxpr_offloadable(self):
@@ -1309,7 +1308,7 @@ class ActivationOffloadingTest(jtu.JaxTestCase):
 
    compiled_stats = compiled_f.memory_analysis()
    if compiled_stats is not None:
-      if xla_extension_version >= 240 and jtu.pjrt_c_api_version_at_least(0, 43):
+      if jtu.pjrt_c_api_version_at_least(0, 43):
        self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)
 
  def test_remat_scan_layout_change_offloadable(self):
@@ -1351,12 +1350,10 @@ class ActivationOffloadingTest(jtu.JaxTestCase):
 
    compiled_stats = compiled_f.memory_analysis()
    if compiled_stats is not None:
-      if xla_extension_version >= 240 and jtu.pjrt_c_api_version_at_least(0, 43):
+      if jtu.pjrt_c_api_version_at_least(0, 43):
        self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)
 
  def test_remat_checkpoint_dots_with_no_batch_dims(self):
-    if not jtu.test_device_matches(["tpu"]) and xla_extension_version < 247:
-      self.skipTest("Test requires a newer jaxlib")
    policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(
        "device", "pinned_host")
 
@@ -1385,7 +1382,7 @@ class ActivationOffloadingTest(jtu.JaxTestCase):
 
    compiled_stats = compiled_f.memory_analysis()
    if compiled_stats is not None and jtu.test_device_matches(["tpu"]):
-      if xla_extension_version >= 240 and jtu.pjrt_c_api_version_at_least(0, 43):
+      if jtu.pjrt_c_api_version_at_least(0, 43):
        self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)
 
 if __name__ == "__main__":
@@ -59,7 +59,6 @@ from jax.interpreters import mlir
 from jax._src import xla_bridge
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension
-from jax._src.lib import xla_extension_version
 from jax._src.util import curry, unzip2
 
 config.parse_flags_with_absl()
@@ -420,7 +419,6 @@ class PJitTest(jtu.BufferDonationTestCase):
    jax.tree.map(self.assertDeleted, y_tree)
    jax.tree.map(self.assertDeleted, z_tree)
 
-  @unittest.skipIf(xla_extension_version < 220, 'jaxlib version too old')
  @jtu.run_on_devices('tpu')
  def testBufferDonationWithOutputShardingInferenceAndTokens(self):
    mesh = jtu.create_global_mesh((2,), 'x')
@@ -4027,9 +4025,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
    f(inps)  # doesn't crash
 
  def test_spmd_preserves_input_sharding_vmap_grad(self):
-    if xla_extension_version < 258:
-      self.skipTest('Requires xla_extension_version >= 258')
-
    # https://github.com/google/jax/issues/20710
    n_devices = jax.device_count()
    sharding = PositionalSharding(jax.devices())
@@ -4068,8 +4063,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
 class TempSharding(Sharding):
 
  def __init__(self, devices):
-    if xla_extension_version >= 235:
-      super().__init__()
+    super().__init__()
    self._devices = devices
    self._internal_device_list = xc.DeviceList(tuple(self._devices))
 
@@ -52,7 +52,6 @@ from jax._src.interpreters import mlir
 from jax._src.interpreters import pxla
 from jax._src.lax import parallel
 from jax._src.lib import xla_extension
-from jax._src.lib import xla_extension_version
 from jax._src.util import safe_map, safe_zip
 
 config.parse_flags_with_absl()
@@ -1112,7 +1111,6 @@ class PythonPmapTest(jtu.JaxTestCase):
                        ((0, 1, 2, 3, 4, 5, 6, 7,),))  # order doesn't matter
 
  @jtu.run_on_devices("gpu")
-  @unittest.skipIf(xla_extension_version < 250, "Requires jaxlib 0.4.26")
  def testCollectiveBroadcast(self):
    device_count = jax.device_count()
    f = lambda x: lax.pbroadcast(x, source=0, axis_name='i')
@@ -1123,7 +1121,6 @@ class PythonPmapTest(jtu.JaxTestCase):
    self.assertAllClose(ans, expected, check_dtypes=False)
 
  @jtu.run_on_devices("gpu")
-  @unittest.skipIf(xla_extension_version < 250, "Requires jaxlib 0.4.26")
  def testCollectiveBroadcastVmap(self):
    device_count = jax.device_count()
    f = lambda x: lax.pbroadcast(x, source=0, axis_name='i')
@@ -1134,7 +1131,6 @@ class PythonPmapTest(jtu.JaxTestCase):
    self.assertAllClose(ans, expected, check_dtypes=False)
 
  @jtu.run_on_devices("gpu")
-  @unittest.skipIf(xla_extension_version < 250, "Requires jaxlib 0.4.26")
  def testCollectiveBroadcastGrad(self):
    device_count = jax.device_count()
    f = lambda x: lax.pbroadcast(x, source=0, axis_name='i')
@@ -2200,7 +2196,6 @@ class PythonPmapTest(jtu.JaxTestCase):
    jax.vmap(jax.vmap(lambda x: 2 * x, axis_name='i'),
             axis_name='i')(jax.numpy.ones((1, 1)))  # don't crash
 
-  @unittest.skipIf(xla_extension_version < 252, "Requires jaxlib 0.4.26")
  @jtu.run_on_devices("cpu")
  def test_pmap_stack_size(self):
    # Regression test for https://github.com/google/jax/issues/20428
@@ -31,7 +31,6 @@ from jax._src import maps
 from jax._src import test_util as jtu
 from jax._src import util
 from jax._src.lib import xla_client
-from jax._src.lib import xla_extension_version
 from jax._src.maps import xmap
 from jax.experimental import io_callback
 from jax.experimental import pjit
@@ -803,7 +802,6 @@ class PureCallbackTest(jtu.JaxTestCase):
        ValueError, "Pure callbacks do not support JVP."):
      f(2.)
 
-  @unittest.skipIf(xla_extension_version < 245, "jaxlib version too old")
  def test_error_propagation(self):
    def throws_error_fn(x):
      raise RuntimeError("Errors should propagate.")
@@ -815,7 +813,6 @@ class PureCallbackTest(jtu.JaxTestCase):
    with self.assertRaisesRegex(Exception, "Errors should propagate."):
      print(np.array(f(2.0)), flush=True)
 
-  @unittest.skipIf(xla_extension_version < 250, "jaxlib version too old")
  def test_reentrant_error_propagation(self):
    reentrant_fn = jax.jit(jnp.sin).lower(2.0).compile()
 
@@ -23,7 +23,6 @@ from jax._src import test_util as jtu
 from jax.sharding import NamedSharding, PartitionSpec as P
 from jax.experimental.shard_alike import shard_alike
 from jax.experimental.shard_map import shard_map
-from jax._src.lib import xla_extension_version
 
 jax.config.parse_flags_with_absl()
 
@@ -63,8 +62,6 @@ class ShardAlikeTest(jtu.JaxTestCase):
 
  def setUp(self):
    super().setUp()
-    if xla_extension_version < 227:
-      self.skipTest('Requires xla_extension_version >= 227')
 
  def test_basic(self):
    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
@@ -18,7 +18,6 @@ import functools
 import pickle
 import re
 from typing import TypeVar
-import unittest
 
 from absl.testing import absltest
 from absl.testing import parameterized
@@ -26,7 +25,6 @@ import jax
 from jax import flatten_util
 from jax import tree_util
 from jax._src import test_util as jtu
-from jax._src.lib import xla_extension_version
 from jax._src.tree_util import flatten_one_level, prefix_errors
 import jax.numpy as jnp
 
@@ -233,27 +231,25 @@ TREE_STRINGS = (
    "PyTreeDef(CustomNode(StaticDict[{'foo': 4, 'bar': 5}], []))",
 )
 
-if xla_extension_version >= 259:
-  @pytree_node_dataclass
-  class ADataclass:
-    x: tuple[int, int]
-    y: int
-
-  @pytree_node_dataclass
-  class ADataclassWithMeta:
-    x: tuple[int, int]
-    y: int
-    z: int = dataclasses.field(metadata={"pytree_node": False})
-
-  TREES += (
-      (ADataclass(x=(1, 2), y=3),),
-      (ADataclassWithMeta(x=(1, 2), y=3, z=4),),
-  )
-  TREE_STRINGS += (
-      "PyTreeDef(CustomNode(ADataclass[()], [(*, *), *]))",
-      "PyTreeDef(CustomNode(ADataclassWithMeta[(4,)], [(*, *), *]))",
-  )
+@pytree_node_dataclass
+class ADataclass:
+  x: tuple[int, int]
+  y: int
+
+@pytree_node_dataclass
+class ADataclassWithMeta:
+  x: tuple[int, int]
+  y: int
+  z: int = dataclasses.field(metadata={"pytree_node": False})
+
+TREES += (
+    (ADataclass(x=(1, 2), y=3),),
+    (ADataclassWithMeta(x=(1, 2), y=3, z=4),),
+)
+TREE_STRINGS += (
+    "PyTreeDef(CustomNode(ADataclass[()], [(*, *), *]))",
+    "PyTreeDef(CustomNode(ADataclassWithMeta[(4,)], [(*, *), *]))",
+)
 
 
 TREES += (
@@ -829,7 +825,6 @@ class TreeTest(jtu.JaxTestCase):
    leaves, _ = tree_util.tree_flatten_with_path(ATuple2(1, 'hi'))
    self.assertLen(leaves, 1)
 
-  @unittest.skipIf(xla_extension_version < 247, "Requires jaxlib>=0.4.26")
  def testBadFlattenNonTuple(self):
    t = BadFlattenNonTuple(3, 4)
    with self.assertRaisesRegex(
@@ -839,7 +834,6 @@ class TreeTest(jtu.JaxTestCase):
    ):
      tree_util.tree_flatten(t)
 
-  @unittest.skipIf(xla_extension_version < 247, "Requires jaxlib>=0.4.26")
  def testBadFlattenBadArityTuple(self):
    t = BadFlattenBadArityTuple(3, 4)
    with self.assertRaisesRegex(
@@ -849,7 +843,6 @@ class TreeTest(jtu.JaxTestCase):
    ):
      tree_util.tree_flatten(t)
 
-  @unittest.skipIf(xla_extension_version < 247, "Requires jaxlib>=0.4.26")
  def testBadFlattenNonIterableLeaves(self):
    t = BadFlattenNonIterableLeaves(3, 4)
    with self.assertRaisesRegex(
@@ -26,7 +26,6 @@ from jax._src import test_util as jtu
 from jax._src import xla_bridge as xb
 from jax._src.interpreters import xla
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 
 config.parse_flags_with_absl()
 
@@ -56,9 +55,7 @@ class XlaBridgeTest(jtu.JaxTestCase):
        num_replicas=1, num_partitions=1, fdo_profile=b"test_profile"
    )
    self.assertEqual(
-        compile_options.executable_build_options.fdo_profile,
-        b"test_profile" if xla_extension_version >= 242 else "test_profile"
-    )
+        compile_options.executable_build_options.fdo_profile, b"test_profile")
 
  def test_autofdo_profile(self):
 