Bump minimum_jaxlib_version to 0.4.14. xla_extension_version is 174 and mlir_api_version is 54

PiperOrigin-RevId: 552816893
Yash Katariya 2023-08-01 08:52:54 -07:00 committed by jax authors
parent 109ed5023d
commit 4ddf6a9a54
19 changed files with 125 additions and 331 deletions
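For reviewers skimming the diff below: with 0.4.14 as the floor, the xla_extension_version and mlir_api_version guards are never false any more, so each call site keeps only its newer branch. A minimal, self-contained sketch of the pattern being deleted (the constant and both branch bodies are illustrative placeholders, not JAX code):

xla_extension_version = 174  # stands in for jax._src.lib.xla_extension_version

def new_path():
  return "path that needs jaxlib >= 0.4.14 (pytree_registry, memory_kind, fdo_profile, ...)"

def old_path():
  return "fallback for older jaxlib, now unreachable"

# Before this commit: every call site branched on the build version.
result = new_path() if xla_extension_version >= 169 else old_path()

# After bumping _minimum_jaxlib_version to 0.4.14, the guard is dropped and only
# the new branch survives:
result = new_path()
print(result)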

View File

@@ -63,7 +63,6 @@ from jax._src.api_util import (
from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
from jax._src.lib import pmap_lib
from jax._src.sharding import Sharding
from jax._src.sharding_impls import PmapSharding
@@ -1853,13 +1852,9 @@ def _cpp_pmap(
    return out, fastpath_data
-  if xla_extension_version >= 169:
-    cpp_mapped_f = pmap_lib.pmap(  # type: ignore
-        fun, cache_miss, static_broadcasted_tuple, pxla.shard_arg,
-        pytree_registry=tree_util.default_registry)
-  else:
-    cpp_mapped_f = pmap_lib.pmap(
-        fun, cache_miss, static_broadcasted_tuple, pxla.shard_arg)  # type: ignore
+  cpp_mapped_f = pmap_lib.pmap(  # type: ignore
+      fun, cache_miss, static_broadcasted_tuple, pxla.shard_arg,
+      pytree_registry=tree_util.default_registry)
_pmap_cache_clears.add(cpp_mapped_f)
pmap_f = wraps(fun)(cpp_mapped_f)

View File

@@ -22,7 +22,6 @@ import numpy as np
from jax._src.config import config
from jax._src.lib import xla_client
-from jax._src.lib import xla_extension_version
from jax._src.lib import version_str as jaxlib_version_str
from jax._src.lib import version as jaxlib_version
from jax._src.lib.mlir import ir
@@ -196,10 +195,7 @@ def _hash_compile_options(hash_obj, compile_options_obj):
def _hash_executable_build_options(hash_obj, executable_obj):
-  if xla_extension_version > 165:
-    expected_options = 11
-  else:
-    expected_options = 10
+  expected_options = 11
# Ignore private and built-in methods. These can unexpectedly change and lead
# to false positives, e.g. when different Python versions include different
# built-ins.
@@ -228,7 +224,7 @@ def _hash_executable_build_options(hash_obj, executable_obj):
  _hash_bool_list(
      hash_obj, executable_obj.allow_spmd_sharding_propagation_to_output
  )
-  if xla_extension_version > 165 and executable_obj.fdo_profile is not None:
+  if executable_obj.fdo_profile is not None:
_hash_string(hash_obj, executable_obj.fdo_profile)

View File

@@ -41,10 +41,8 @@ from jax._src import xla_bridge as xb
from jax._src.config import config
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
-from jax._src.lib import mlir_api_version
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension
-from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import dialects
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
@@ -972,10 +970,7 @@ def lower_jaxpr_to_fun(
                          in zip(replicated_args, input_types)]
    for attrs, replicated in zip(arg_attrs, util.flatten(replicated_ir_args)):
      if replicated:
-        if xla_extension_version < 172:
-          attrs["mhlo.is_same_data_across_replicas"] = ir.UnitAttr.get()
-        else:
-          attrs["mhlo.is_same_data_across_replicas"] = ir.BoolAttr.get(True)
+        attrs["mhlo.is_same_data_across_replicas"] = ir.BoolAttr.get(True)
if use_sharding_annotations and ir_arg_shardings is not None:
for attrs, sharding in zip(arg_attrs, ir_arg_shardings):
@@ -1001,12 +996,6 @@ lower_jaxpr_to_fun(
    for attrs in token_arg_attrs:
      attrs["jax.token"] = ir.BoolAttr.get(True)
-  if mlir_api_version < 54 and arg_names:
-    named_arg_attrs = arg_attrs[num_dim_vars + num_tokens:]
-    for attrs, name_ in zip(named_arg_attrs, arg_names):
-      if name_:
-        attrs['jax.arg_info'] = ir.StringAttr.get(name_)
func_op.arg_attrs = ir.ArrayAttr.get(
[ir.DictAttr.get(attrs) for attrs in arg_attrs])
@@ -1032,7 +1021,7 @@ lower_jaxpr_to_fun(
  func_op.result_attrs = ir.ArrayAttr.get(
      [ir.DictAttr.get(attrs) for attrs in result_attrs])
-  if mlir_api_version >= 54 and arg_names:
+  if arg_names:
arg_locs = [ir.Location.unknown()] * (num_dim_vars + num_tokens)
for n in arg_names:
arg_locs.append(ir.Location.name(n) if n else ir.Location.unknown())
@@ -1907,14 +1896,9 @@ def _emit_tpu_python_callback(
                                   callback.__name__, sharding=sharding)
    outputs.append(out)
    recv_channels.append(channel)
-  if xla_extension_version < 161:
-    opaque = backend.make_python_callback_from_host_send_and_recv(
-        _wrapped_callback, operand_shapes, result_shapes, send_channels,
-        recv_channels)  # type: ignore  # pylint: disable=missing-parameter
-  else:
-    opaque = backend.make_python_callback_from_host_send_and_recv(
-        _wrapped_callback, operand_shapes, result_shapes, send_channels,
-        recv_channels, pickle_util.dumps)  # type: ignore  # pylint: disable=missing-parameter
+  opaque = backend.make_python_callback_from_host_send_and_recv(
+      _wrapped_callback, operand_shapes, result_shapes, send_channels,
+      recv_channels, pickle_util.dumps)  # type: ignore  # pylint: disable=missing-parameter
ctx.module_context.add_host_callback(opaque)
return outputs, token, opaque
@@ -2183,18 +2167,9 @@ def refine_polymorphic_shapes(module: ir.Module) -> ir.Module:
  shape polymorphism, runs shape refinement to resolve all the dynamic shapes.
  Then verifies that there are no more dynamic shapes in the module.
  """
-  if xc.mlir_api_version >= 53:
-    refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
-        module_to_bytecode(module), enable_shape_assertions=True,
-        validate_static_shapes=True)
-  elif xc.mlir_api_version == 52:
-    refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
-        module_to_bytecode(module), enable_shape_assertions=True)
-  elif xc.mlir_api_version >= 50:
-    refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
-        module_to_bytecode(module))
-  else:
-    raise NotImplementedError("refine_polymorphic_shapes needs jaxlib 0.4.12")
+  refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
+      module_to_bytecode(module), enable_shape_assertions=True,
+      validate_static_shapes=True)
context = make_ir_context()
with context:

View File

@@ -57,7 +57,6 @@ from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.partition_spec import PartitionSpec
@@ -2730,12 +2729,8 @@ class MeshExecutable(stages.XlaExecutable):
        fastpath_data = None
      return outs, fastpath_data
-    if xla_extension_version >= 169:
-      return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [],
-                          tree_util.default_registry)  # type: ignore
-    else:
-      return xc._xla.pjit(
-          self.unsafe_call.name, None, aot_cache_miss, [], [], [])  # type: ignore
+    return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [],
+                        tree_util.default_registry)  # type: ignore
def check_arg_avals_for_call(ref_avals, arg_avals,

View File

@@ -342,35 +342,20 @@ def _approx_top_k_lowering(ctx, operand, *, k,
  if fallback:
    backend_config["is_fallback"] = mlir.ir.BoolAttr.get(fallback)
-  if xc.mlir_api_version >= 51:  # jaxlib >= 0.4.14
-    if all(core.is_constant_shape(aval_out.shape) for aval_out in ctx.avals_out):
-      result_shapes = None
-    else:
-      result_shapes = [
-          mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, aval_out.shape))
-          for aval_out in ctx.avals_out]
-    out = mlir.custom_call(
-        "ApproxTopK",
-        [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
-        [operand, iota, init_val, init_arg],
-        called_computations=[comparator.name.value],
-        backend_config=backend_config,
-        result_shapes=result_shapes)
+  if all(core.is_constant_shape(aval_out.shape) for aval_out in ctx.avals_out):
+    result_shapes = None
  else:
-    # Older versions do not support has_side_effect attribute; we just use
-    # the old lowering code.
-    if any(not core.is_constant_shape(aval_out.shape) for aval_out in ctx.avals_out):
-      raise ValueError("approx_top_k not supported with shape polymorphism; "
-                       "try upgrading jaxlib")
-    out = hlo.CustomCallOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
-                           [operand, iota, init_val, init_arg],
-                           call_target_name=b"ApproxTopK",
-                           called_computations=mlir.ir.ArrayAttr.get(
-                               [mlir.ir.FlatSymbolRefAttr.get(comparator.name.value)]))
-    backend_config_attr = mlir.ir.DictAttr.get(backend_config,
-                                               ctx.module_context.context)
-    out.operation.attributes["mhlo.backend_config"] = backend_config_attr
+    result_shapes = [
+        mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, aval_out.shape))
+        for aval_out in ctx.avals_out]
+  out = mlir.custom_call(
+      "ApproxTopK",
+      [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
+      [operand, iota, init_val, init_arg],
+      called_computations=[comparator.name.value],
+      backend_config=backend_config,
+      result_shapes=result_shapes)
return out.results

View File

@@ -4194,9 +4194,6 @@ top_k_p.def_abstract_eval(_top_k_abstract_eval)
def _top_k_lower(ctx, operand, k):
  if core.is_constant_dim(k):
    return chlo.TopKOp(operand, mlir.i64_attr(k)).results
-  if xla_client.mlir_api_version < 54:
-    # TODO: https://github.com/openxla/stablehlo/issues/1396
-    raise ValueError("native serialization with shape polymorphism not implemented for top_k")
k_value, = mlir.eval_dynamic_shape_as_vals(ctx, (k,))
out_values_aval, out_indices_aval, = ctx.avals_out
return mlir.custom_call(

View File

@@ -54,7 +54,6 @@ from jax._src.interpreters import pxla
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
from jax._src.sharding_impls import (
NamedSharding, XLACompatibleSharding, GSPMDSharding,
XLADeviceAssignment, SingleDeviceSharding, PmapSharding,
@@ -256,17 +255,11 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
    fastpath_data = _get_fastpath_data(executable, out_tree, args_flat, out_flat)
    return outs, fastpath_data
-  if xla_extension_version >= 169:
-    cpp_pjit_f = xc._xla.pjit(  # type: ignore
-        getattr(fun, "__name__", "<unnamed function>"),  # type: ignore
-        fun, cache_miss, static_argnums, static_argnames,  # type: ignore
-        donate_argnums, tree_util.default_registry,  # type: ignore
-        _get_cpp_global_cache(pjit_has_explicit_sharding))  # type: ignore
-  else:
-    cpp_pjit_f = xc._xla.pjit(  # type: ignore
-        getattr(fun, "__name__", "<unnamed function>"),  # type: ignore
-        fun, cache_miss, static_argnums, static_argnames,  # type: ignore
-        donate_argnums, _get_cpp_global_cache(pjit_has_explicit_sharding))  # type: ignore
+  cpp_pjit_f = xc._xla.pjit(  # type: ignore
+      getattr(fun, "__name__", "<unnamed function>"),  # type: ignore
+      fun, cache_miss, static_argnums, static_argnames,  # type: ignore
+      donate_argnums, tree_util.default_registry,  # type: ignore
+      _get_cpp_global_cache(pjit_has_explicit_sharding))  # type: ignore
cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
@@ -1207,13 +1200,9 @@ def _pjit_call_impl(*args, jaxpr,
  donated_argnums = [i for i, d in enumerate(donated_invars) if d]
  has_explicit_sharding = _pjit_explicit_sharding(
      in_shardings, out_shardings, None, None)
-  if xla_extension_version >= 169:
-    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
-                        tree_util.default_registry,
-                        _get_cpp_global_cache(has_explicit_sharding))(*args)
-  else:
-    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,  # type: ignore
-                        _get_cpp_global_cache(has_explicit_sharding))(*args)
+  return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
+                      tree_util.default_registry,
+                      _get_cpp_global_cache(has_explicit_sharding))(*args)
pjit_p.def_impl(_pjit_call_impl)

View File

@@ -36,7 +36,6 @@ from jax._src import util
from jax._src import xla_bridge
from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method
from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
from jax._src.partition_spec import PartitionSpec
import numpy as np
@@ -108,15 +107,10 @@ class XLACompatibleSharding(sharding.Sharding):
  def is_equivalent_to(self: XLACompatibleSharding,  # type: ignore
                       other: XLACompatibleSharding, ndim: int) -> bool:
    try:
-      if xla_extension_version >= 168:
-        return (are_op_shardings_equal(self._to_xla_hlo_sharding(ndim),
-                                       other._to_xla_hlo_sharding(ndim))
-                and self._device_assignment == other._device_assignment and
-                self.memory_kind == other.memory_kind)
-      else:
-        return (are_op_shardings_equal(self._to_xla_hlo_sharding(ndim),
-                                       other._to_xla_hlo_sharding(ndim))
-                and self._device_assignment == other._device_assignment)
+      return (are_op_shardings_equal(self._to_xla_hlo_sharding(ndim),
+                                     other._to_xla_hlo_sharding(ndim))
+              and self._device_assignment == other._device_assignment and
+              self.memory_kind == other.memory_kind)
# NotImplementedError is raised by PmapSharding because it can't lower
# to OpSharding. So if `other` is a PmapSharding, default to a strict
# equality check.
@@ -213,17 +207,14 @@ class NamedSharding(XLACompatibleSharding):
    self._preprocess()
  def __reduce__(self):
-    if xla_extension_version >= 168:
-      return (
-          type(self),
-          (self.mesh, self.spec),
-          {'memory_kind': self.memory_kind},
-      )
-    else:
-      return type(self), (self.mesh, self.spec)
+    return (
+        type(self),
+        (self.mesh, self.spec),
+        {'memory_kind': self.memory_kind},
+    )
  def _preprocess(self):
-    if xla_extension_version >= 170 and self.memory_kind is not None:
+    if self.memory_kind is not None:
# Will error if memory_kind does not exist on the device.
self.mesh.devices.flat[0].memory(self.memory_kind)
@@ -242,18 +233,12 @@ class NamedSharding(XLACompatibleSharding):
    _check_mesh_resource_axis(self.mesh, self._parsed_pspec)
  def __repr__(self):
-    if xla_extension_version >= 168:
-      mem = '' if self.memory_kind is None else f', memory_kind={self.memory_kind}'
-      return f'NamedSharding(mesh={dict(self.mesh.shape)}, spec={self.spec}{mem})'
-    else:
-      return f'NamedSharding(mesh={dict(self.mesh.shape)}, spec={self.spec})'
+    mem = '' if self.memory_kind is None else f', memory_kind={self.memory_kind}'
+    return f'NamedSharding(mesh={dict(self.mesh.shape)}, spec={self.spec}{mem})'
  def __hash__(self):
    if not hasattr(self, '_hash'):
-      if xla_extension_version >= 168:
-        self._hash = hash((self.mesh, self.memory_kind, self._parsed_pspec))
-      else:
-        self._hash = hash((self.mesh, self._parsed_pspec))
+      self._hash = hash((self.mesh, self.memory_kind, self._parsed_pspec))
return self._hash
def __eq__(self, other):
@@ -262,16 +247,11 @@ class NamedSharding(XLACompatibleSharding):
    if id(self) == id(other):
      return True
    parsed_pspec_equal = self._parsed_pspec == other._parsed_pspec
-    if xla_extension_version >= 168:
-      if (id(self.mesh) == id(other.mesh) and
-          self.memory_kind == other.memory_kind and parsed_pspec_equal):
-        return True
-      return (self.mesh == other.mesh and self.memory_kind == other.memory_kind
-              and parsed_pspec_equal)
-    else:
-      if id(self.mesh) == id(other.mesh) and parsed_pspec_equal:
-        return True
-      return self.mesh == other.mesh and parsed_pspec_equal
+    if (id(self.mesh) == id(other.mesh) and
+        self.memory_kind == other.memory_kind and parsed_pspec_equal):
+      return True
+    return (self.mesh == other.mesh and self.memory_kind == other.memory_kind
+            and parsed_pspec_equal)
def is_compatible_aval(self, aval_shape: Shape):
assert self._parsed_pspec is not None
@@ -285,12 +265,8 @@ class NamedSharding(XLACompatibleSharding):
  @classmethod
  def _from_parsed_pspec(cls, mesh, parsed_pspec, *, memory_kind=None):
-    if xla_extension_version >= 168:
-      return cls(mesh, parsed_pspec.get_partition_spec(),
-                 memory_kind=memory_kind, _parsed_pspec=parsed_pspec)
-    else:
-      return cls(mesh, parsed_pspec.get_partition_spec(),
-                 _parsed_pspec=parsed_pspec)
+    return cls(mesh, parsed_pspec.get_partition_spec(),
+               memory_kind=memory_kind, _parsed_pspec=parsed_pspec)
@property
def device_set(self) -> set[Device]:
@@ -377,24 +353,15 @@ class SingleDeviceSharding(XLACompatibleSharding):
    self._memory_kind = memory_kind
  def __reduce__(self):
-    if xla_extension_version >= 168:
-      return type(self), (self._device,), {'memory_kind': self._memory_kind}
-    else:
-      return type(self), (self._device,)
+    return type(self), (self._device,), {'memory_kind': self._memory_kind}
  def __repr__(self):
-    if xla_extension_version >= 168:
-      mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}'
-      return f"SingleDeviceSharding(device={repr(self._device)}{mem})"
-    else:
-      return f"SingleDeviceSharding(device={repr(self._device)})"
+    mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}'
+    return f"SingleDeviceSharding(device={repr(self._device)}{mem})"
  def __hash__(self):
    if not hasattr(self, '_hash'):
-      if xla_extension_version >= 168:
-        self._hash = hash((self._device, self._memory_kind))
-      else:
-        self._hash = hash(self._device)
+      self._hash = hash((self._device, self._memory_kind))
return self._hash
def __eq__(self, other):
@@ -402,11 +369,8 @@ class SingleDeviceSharding(XLACompatibleSharding):
      return False
    if id(self) == id(other):
      return True
-    if xla_extension_version >= 168:
-      return (self._device == other._device and
-              self._memory_kind == other._memory_kind)
-    else:
-      return self._device == other._device
+    return (self._device == other._device and
+            self._memory_kind == other._memory_kind)
@property
def device_set(self) -> set[Device]:
@@ -414,10 +378,7 @@ class SingleDeviceSharding(XLACompatibleSharding):
  @property
  def memory_kind(self) -> str | None:
-    if xla_extension_version >= 168:
-      return self._memory_kind
-    else:
-      return None
+    return self._memory_kind
def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]: # type: ignore
return {self._device: (slice(None),) * len(global_shape)}
@@ -671,10 +632,7 @@ class PositionalSharding(XLACompatibleSharding):
  @property
  def memory_kind(self) -> str | None:
-    if xla_extension_version >= 168:
-      return self._memory_kind
-    else:
-      return None
+    return self._memory_kind
@functools.cached_property
def is_fully_replicated(self) -> bool:
@@ -753,22 +711,12 @@ class GSPMDSharding(XLACompatibleSharding):
    self._hlo_sharding = op_sharding
    self._memory_kind = memory_kind
-  if xla_extension_version < 159:
-    @property
-    def _hlo_sharding(self):  # type: ignore
-      if isinstance(self._op_sharding, xc.OpSharding):  # type: ignore
-        return xc.HloSharding.from_proto(self._op_sharding)  # type: ignore
-      return self._op_sharding  # type: ignore
  def __reduce__(self):
-    if xla_extension_version >= 168:
-      return (
-          type(self),
-          (self._devices, self._hlo_sharding.to_proto()),
-          {'memory_kind': self._memory_kind},
-      )
-    else:
-      return type(self), (self._devices, self._hlo_sharding.to_proto())
+    return (
+        type(self),
+        (self._devices, self._hlo_sharding.to_proto()),
+        {'memory_kind': self._memory_kind},
+    )
@functools.cached_property
def _hlo_sharding_hash(self):
@@ -779,29 +727,19 @@ class GSPMDSharding(XLACompatibleSharding):
      return False
    if id(self) == id(other):
      return True
-    if xla_extension_version >= 168:
-      return (are_op_shardings_equal(self._hlo_sharding, other._hlo_sharding)
-              and self._devices == other._devices and
-              self._memory_kind == other._memory_kind)
-    else:
-      return (are_op_shardings_equal(self._hlo_sharding, other._hlo_sharding)
-              and self._devices == other._devices)
+    return (are_op_shardings_equal(self._hlo_sharding, other._hlo_sharding)
+            and self._devices == other._devices and
+            self._memory_kind == other._memory_kind)
  def __hash__(self):
    if not hasattr(self, '_hash'):
-      if xla_extension_version >= 168:
-        self._hash = hash((self._devices, self._hlo_sharding_hash,
-                           self._memory_kind))
-      else:
-        self._hash = hash((self._devices, self._hlo_sharding_hash))
+      self._hash = hash((self._devices, self._hlo_sharding_hash,
+                         self._memory_kind))
    return self._hash
  def __repr__(self):
-    if xla_extension_version >= 168:
-      mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}'
-      return f'GSPMDSharding({repr(self._hlo_sharding)}{mem})'
-    else:
-      return f'GSPMDSharding({repr(self._hlo_sharding)})'
+    mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}'
+    return f'GSPMDSharding({repr(self._hlo_sharding)}{mem})'
def is_compatible_aval(self, aval_shape: Shape):
num_ways_dim_sharded, _ = get_num_ways_dim_sharded(self._hlo_sharding)
@@ -817,10 +755,7 @@ class GSPMDSharding(XLACompatibleSharding):
  @property
  def memory_kind(self) -> str | None:
-    if xla_extension_version >= 168:
-      return self._memory_kind
-    else:
-      return None
+    return self._memory_kind
@functools.lru_cache(maxsize=4096)
def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
@@ -842,11 +777,8 @@ class GSPMDSharding(XLACompatibleSharding):
  @classmethod
  def get_replicated(cls, device_assignment, *, memory_kind: str | None = None):
-    if xla_extension_version >= 168:
-      return cls(tuple(device_assignment), get_replicated_hlo_sharding(),
-                 memory_kind=memory_kind)
-    else:
-      return cls(tuple(device_assignment), get_replicated_hlo_sharding())
+    return cls(tuple(device_assignment), get_replicated_hlo_sharding(),
+               memory_kind=memory_kind)
class AUTO:

View File

@@ -26,7 +26,6 @@ import warnings
from jax._src import traceback_util
from jax._src.lib import pytree
-from jax._src.lib import xla_extension_version
from jax._src.util import safe_zip
from jax._src.util import unzip2
@@ -41,14 +40,11 @@ PyTreeDef = pytree.PyTreeDef
# TODO(phawkins): make this unconditional when jaxlib 0.4.14 is the minimum.
default_registry: pytree.PyTreeRegistry | None
-if xla_extension_version >= 169:
-  default_registry = pytree.default_registry()
-  # Set __module__ and __name__, which allow this registry to be pickled by
-  # reference.
-  default_registry.__module__ = __name__
-  default_registry.__name__ = "default_registry"
-else:
-  default_registry = None
+default_registry = pytree.default_registry()
+# Set __module__ and __name__, which allow this registry to be pickled by
+# reference.
+default_registry.__module__ = __name__
+default_registry.__name__ = "default_registry"
def tree_flatten(tree: Any,
is_leaf: Callable[[Any], bool] | None = None

View File

@@ -40,7 +40,6 @@ from jax._src import distributed
from jax._src import config as jax_config
from jax._src.config import bool_env, config, int_env
from jax._src.lib import xla_client
-from jax._src.lib import xla_extension_version
from jax._src import traceback_util
from jax._src import util
@@ -132,7 +131,7 @@ def get_compile_options(
  build_options = compile_options.executable_build_options
  build_options.use_spmd_partitioning = use_spmd_partitioning
  build_options.use_auto_spmd_partitioning = use_auto_spmd_partitioning
-  if xla_extension_version > 165 and fdo_profile is not None:
+  if fdo_profile is not None:
build_options.fdo_profile = fdo_profile
if use_auto_spmd_partitioning:
build_options.auto_spmd_partitioning_mesh_shape = auto_spmd_partitioning_mesh_shape
@@ -270,23 +269,13 @@ def make_gpu_client(
  if visible_devices != "all":
    allowed_devices = {int(x) for x in visible_devices.split(",")}
-  if xla_extension_version < 160:
-    return xla_client.make_gpu_client(
-        distributed_client=distributed.global_state.client,
-        node_id=distributed.global_state.process_id,
-        platform_name=platform_name,
-        allowed_devices=allowed_devices,
-    )
-  else:
-    # Remove `type: ignore` when the min jaxlib version (xla_extension_version)
-    # >= 160.
-    return xla_client.make_gpu_client(
-        distributed_client=distributed.global_state.client,
-        node_id=distributed.global_state.process_id,
-        num_nodes=distributed.global_state.num_processes,
-        platform_name=platform_name,
-        allowed_devices=allowed_devices,
-    )  # type: ignore
+  return xla_client.make_gpu_client(
+      distributed_client=distributed.global_state.client,
+      node_id=distributed.global_state.process_id,
+      num_nodes=distributed.global_state.num_processes,
+      platform_name=platform_name,
+      allowed_devices=allowed_devices,
+  )
if hasattr(xla_client, "make_gpu_client"):
@@ -470,20 +459,17 @@ def register_plugin(
)
xla_client.load_pjrt_plugin_dynamically(plugin_name, library_path)
-    if xla_extension_version < 165:
-      return xla_client.make_c_api_client(plugin_name, options)
-    else:
-      if distributed.global_state.client is None:
-        return xla_client.make_c_api_client(plugin_name, options, None)
-      distribute_options = {
-          'node_id': distributed.global_state.process_id,
-          'num_nodes': distributed.global_state.num_processes,
-      }
-      if options is not None:
-        distribute_options.update(options)
-      return xla_client.make_c_api_client(
-          plugin_name, distribute_options, distributed.global_state.client
-      )
+    if distributed.global_state.client is None:
+      return xla_client.make_c_api_client(plugin_name, options, None)
+    distribute_options = {
+        'node_id': distributed.global_state.process_id,
+        'num_nodes': distributed.global_state.num_processes,
+    }
+    if options is not None:
+      distribute_options.update(options)
+    return xla_client.make_c_api_client(
+        plugin_name, distribute_options, distributed.global_state.client
+    )
logger.debug(

View File

@@ -842,11 +842,6 @@ def _export_native_vjp(primal_fun_jax, primal: Exported) -> Exported:
def call_exported(exported: Exported) -> Callable[..., jax.Array]:
-  if (exported.serialization_version >= 7 and
-      exported.uses_shape_polymorphism):
-    if xla_client.mlir_api_version < 52:
-      raise NotImplementedError(
-          "Current jaxlib does not support shape polymorphism with serialization version >= 7")
@jax.custom_vjp
def f_flat(*args_flat):
return call_exported_p.bind(*args_flat, exported=exported)

View File

@@ -298,9 +298,6 @@ class JaxExportTest(jtu.JaxTestCase):
ValueError,
f"The requested jax_serialization version {v} is outside the range of supported versions"))
-    if (xc.mlir_api_version <= 51 and
-        config.jax_serialization_version >= 7):
-      raise unittest.SkipTest("Not supported in old jaxlib")
exp = jax_export.export(jnp.sin)(
jax_export.poly_spec((3, 4), np.float32, "w, h"))
# Peek at the module
@@ -347,15 +344,10 @@ class JaxExportTest(jtu.JaxTestCase):
arg_shape=(3, 4, 12), arg_dtype=np.float32,
expect_error=None): # If given, error from running the exported module
-    if xc.mlir_api_version <= 51:
-      raise unittest.SkipTest("Not supported in old jaxlib")
def f(x): # x: f32[poly_spec]
return jnp.reshape(x, (-1, x.shape[1]))
-    if xc.mlir_api_version <= 51:
-      disabled_checks = (jax_export.DisabledSafetyCheck.shape_assertions(),)
-    else:
-      disabled_checks = ()
+    disabled_checks = ()
exp_f = jax_export.export(f, disabled_checks=disabled_checks)(
jax_export.poly_spec((3, 4, 12), np.float32, poly_spec))
self.assertEqual(exp_f.uses_shape_polymorphism, poly_spec != "3,4,12")
@@ -453,8 +445,6 @@ class JaxExportTest(jtu.JaxTestCase):
expect_error_outer_exp=None,
expect_error_run=None):
# Polymorphic export called with static or polymorphic shapes
-    if xc.mlir_api_version <= 51:
-      raise unittest.SkipTest("Not supported in old jaxlib")
def inner(x): # x: inner_poly_spec
return jnp.reshape(x, (-1, x.shape[1]))

View File

@@ -16,7 +16,7 @@
# eval()-ed by setup.py, so it should not have any dependencies.
__version__ = "0.4.15"
-_minimum_jaxlib_version = "0.4.11"
+_minimum_jaxlib_version = "0.4.14"
def _version_as_tuple(version_str):
return tuple(int(i) for i in version_str.split(".") if i.isdigit())
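The tuple form matters for the new floor: version strings do not compare correctly as strings. The check site is not part of this hunk, so the following is only an illustrative, self-contained sketch of how _version_as_tuple supports the minimum-version comparison:

def _version_as_tuple(version_str):
  return tuple(int(i) for i in version_str.split(".") if i.isdigit())

assert _version_as_tuple("0.4.13") == (0, 4, 13)
# Element-wise numeric comparison orders versions correctly:
assert _version_as_tuple("0.4.9") < _version_as_tuple("0.4.14")
# whereas naive string comparison gets the same pair wrong:
assert "0.4.9" > "0.4.14"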

View File

@@ -60,7 +60,6 @@ from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.lib import xla_client
from jax._src.lib import xla_extension
-from jax._src.lib import xla_extension_version
import jax._src.util as jax_util
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
import jax.custom_batching
@@ -1209,12 +1208,8 @@ class CPPJitTest(jtu.BufferDonationTestCase):
  def test_jit_lower_compile_cost_analysis(self):
    f = self.jit(lambda x: x).lower(1.).compile()
    g = self.jit(lambda x: x + 4).lower(1.).compile()
-    if xla_extension_version >= 164:
-      self.assertIsNotNone(f.cost_analysis())
-      self.assertIsNotNone(g.cost_analysis())
-    else:
-      f.cost_analysis()  # doesn't raise
-      g.cost_analysis()  # doesn't raise
+    self.assertIsNotNone(f.cost_analysis())
+    self.assertIsNotNone(g.cost_analysis())
@jtu.skip_on_xla_cpu_mlir
def test_jit_lower_compile_memory_analysis(self):
@@ -1283,8 +1278,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
    self.assertIn("jax.result_info = \"['a']\"", mhlo_str)
    self.assertIn("jax.result_info = \"['b'][0][0]\"", mhlo_str)
-  @unittest.skipIf(xla_extension_version < 171,
-                   'Test requires xla_extension_version >= 171')
def test_jit_lower_compile_with_compiler_options(self):
def f(x):
return jnp.sqrt(x ** 2) + 1.

View File

@@ -21,7 +21,6 @@ from jax import config
import jax.dlpack
import jax.numpy as jnp
from jax._src import test_util as jtu
-from jax._src.lib import xla_extension_version
import numpy as np
@@ -164,7 +163,6 @@ class DLPackTest(jtu.JaxTestCase):
    self.assertAllClose(x_np, x_jax)
-@unittest.skipIf(xla_extension_version < 163, "Test requires jaxlib 0.4.13")
class CudaArrayInterfaceTest(jtu.JaxTestCase):
def setUp(self):

View File

@@ -28,7 +28,6 @@ from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.config import compilation_cache_include_metadata_in_key
from jax._src.lib import xla_client
-from jax._src.lib import xla_extension_version
import numpy as np
@@ -271,8 +270,7 @@ class CacheKeyTest(jtu.JaxTestCase):
    compile_options.executable_build_options.device_assignment = (
        device_assignment
    )
-    if xla_extension_version > 165:
-      compile_options.executable_build_options.fdo_profile = b"test_profile"
+    compile_options.executable_build_options.fdo_profile = b"test_profile"
return compile_options
def get_hashed_value(self, hash_function, hash_function_input):

View File

@@ -13,15 +13,12 @@
# limitations under the License.
"""Tests for release_backend_clients."""
-import unittest
from absl.testing import absltest
import jax
from jax import config
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
-from jax._src.lib import xla_extension_version
config.parse_flags_with_absl()
@@ -29,9 +26,6 @@ config.parse_flags_with_absl()
class ClearBackendsTest(jtu.JaxTestCase):
  def test_clear_backends(self):
-    if xla_extension_version < 164 and xb.using_pjrt_c_api():
-      raise unittest.SkipTest('test crashes runtime with PJRT C API')
g = jax.jit(lambda x, y: x * y)
self.assertEqual(g(1, 2), 2)
self.assertNotEmpty(xb.get_backend().live_executables())

View File

@@ -13,7 +13,6 @@
# limitations under the License.
from functools import partial
-import unittest
import glob
import logging
import math
@@ -26,7 +25,6 @@ from jax import config
from jax._src import test_util as jtu
from jax.sharding import NamedSharding
from jax.experimental import profiler as exp_profiler
-from jax._src.lib import xla_extension_version
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
import numpy as np
@@ -37,8 +35,6 @@ config.parse_flags_with_absl()
@jtu.pytest_mark_if_available('multiaccelerator')
class PgleTest(jtu.JaxTestCase):
-  @unittest.skipIf(xla_extension_version < 169,
-                   'Requires xla_extension_version >= 169')
def testPassingFDOProfile(self):
mesh = jtu.create_global_mesh((2,), ('x',))
@partial(

View File

@@ -22,7 +22,6 @@ from absl.testing import absltest
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
from jax.interpreters import xla
from jax._src.config import config
@@ -51,13 +50,12 @@ class XlaBridgeTest(jtu.JaxTestCase):
expected_device_assignment)
def test_set_fdo_profile(self):
-    if xla_extension_version > 166:
-      compile_options = xb.get_compile_options(
-          num_replicas=1, num_partitions=1, fdo_profile=b"test_profile"
-      )
-      self.assertEqual(
-          compile_options.executable_build_options.fdo_profile, "test_profile"
-      )
+    compile_options = xb.get_compile_options(
+        num_replicas=1, num_partitions=1, fdo_profile=b"test_profile"
+    )
+    self.assertEqual(
+        compile_options.executable_build_options.fdo_profile, "test_profile"
+    )
def test_parameter_replication_default(self):
c = xc.XlaBuilder("test")
@@ -125,10 +123,7 @@ class XlaBridgeTest(jtu.JaxTestCase):
    self.assertEqual(registration.priority, 400)
    self.assertTrue(registration.experimental)
    mock_plugin_loaded.assert_called_once_with("name1")
-    if xla_extension_version < 165:
-      mock_make.assert_called_once_with("name1", None)
-    else:
-      mock_make.assert_called_once_with("name1", None, None)
+    mock_make.assert_called_once_with("name1", None, None)
def test_register_plugin_with_config(self):
test_json_file_path = os.path.join(
@@ -149,27 +144,16 @@ class XlaBridgeTest(jtu.JaxTestCase):
    self.assertEqual(registration.priority, 400)
    self.assertTrue(registration.experimental)
    mock_plugin_loaded.assert_called_once_with("name1")
-    if xla_extension_version < 165:
-      mock_make.assert_called_once_with(
-          "name1",
-          {
-              "int_option": 64,
-              "int_list_option": [32, 64],
-              "string_option": "string",
-              "float_option": 1.0,
-          },
-      )
-    else:
-      mock_make.assert_called_once_with(
-          "name1",
-          {
-              "int_option": 64,
-              "int_list_option": [32, 64],
-              "string_option": "string",
-              "float_option": 1.0,
-          },
-          None,
-      )
+    mock_make.assert_called_once_with(
+        "name1",
+        {
+            "int_option": 64,
+            "int_list_option": [32, 64],
+            "string_option": "string",
+            "float_option": 1.0,
+        },
+        None,
+    )
class GetBackendTest(jtu.JaxTestCase):