mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Bump minimum_jaxlib_version to 0.4.14. xla_extension_version
is 174 and mlir_api_version
is 54
PiperOrigin-RevId: 552816893
This commit is contained in:
parent
109ed5023d
commit
4ddf6a9a54
@ -63,7 +63,6 @@ from jax._src.api_util import (
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lib import jax_jit
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib import pmap_lib
|
||||
from jax._src.sharding import Sharding
|
||||
from jax._src.sharding_impls import PmapSharding
|
||||
@ -1853,13 +1852,9 @@ def _cpp_pmap(
|
||||
|
||||
return out, fastpath_data
|
||||
|
||||
if xla_extension_version >= 169:
|
||||
cpp_mapped_f = pmap_lib.pmap( # type: ignore
|
||||
fun, cache_miss, static_broadcasted_tuple, pxla.shard_arg,
|
||||
pytree_registry=tree_util.default_registry)
|
||||
else:
|
||||
cpp_mapped_f = pmap_lib.pmap(
|
||||
fun, cache_miss, static_broadcasted_tuple, pxla.shard_arg) # type: ignore
|
||||
cpp_mapped_f = pmap_lib.pmap( # type: ignore
|
||||
fun, cache_miss, static_broadcasted_tuple, pxla.shard_arg,
|
||||
pytree_registry=tree_util.default_registry)
|
||||
_pmap_cache_clears.add(cpp_mapped_f)
|
||||
|
||||
pmap_f = wraps(fun)(cpp_mapped_f)
|
||||
|
@ -22,7 +22,6 @@ import numpy as np
|
||||
|
||||
from jax._src.config import config
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib import version_str as jaxlib_version_str
|
||||
from jax._src.lib import version as jaxlib_version
|
||||
from jax._src.lib.mlir import ir
|
||||
@ -196,10 +195,7 @@ def _hash_compile_options(hash_obj, compile_options_obj):
|
||||
|
||||
|
||||
def _hash_executable_build_options(hash_obj, executable_obj):
|
||||
if xla_extension_version > 165:
|
||||
expected_options = 11
|
||||
else:
|
||||
expected_options = 10
|
||||
expected_options = 11
|
||||
# Ignore private and built-in methods. These can unexpectedly change and lead
|
||||
# to false positives, e.g. when different Python versions include different
|
||||
# built-ins.
|
||||
@ -228,7 +224,7 @@ def _hash_executable_build_options(hash_obj, executable_obj):
|
||||
_hash_bool_list(
|
||||
hash_obj, executable_obj.allow_spmd_sharding_propagation_to_output
|
||||
)
|
||||
if xla_extension_version > 165 and executable_obj.fdo_profile is not None:
|
||||
if executable_obj.fdo_profile is not None:
|
||||
_hash_string(hash_obj, executable_obj.fdo_profile)
|
||||
|
||||
|
||||
|
@ -41,10 +41,8 @@ from jax._src import xla_bridge as xb
|
||||
from jax._src.config import config
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.lib import mlir_api_version
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib.mlir import dialects
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import func as func_dialect
|
||||
@ -972,10 +970,7 @@ def lower_jaxpr_to_fun(
|
||||
in zip(replicated_args, input_types)]
|
||||
for attrs, replicated in zip(arg_attrs, util.flatten(replicated_ir_args)):
|
||||
if replicated:
|
||||
if xla_extension_version < 172:
|
||||
attrs["mhlo.is_same_data_across_replicas"] = ir.UnitAttr.get()
|
||||
else:
|
||||
attrs["mhlo.is_same_data_across_replicas"] = ir.BoolAttr.get(True)
|
||||
attrs["mhlo.is_same_data_across_replicas"] = ir.BoolAttr.get(True)
|
||||
|
||||
if use_sharding_annotations and ir_arg_shardings is not None:
|
||||
for attrs, sharding in zip(arg_attrs, ir_arg_shardings):
|
||||
@ -1001,12 +996,6 @@ def lower_jaxpr_to_fun(
|
||||
for attrs in token_arg_attrs:
|
||||
attrs["jax.token"] = ir.BoolAttr.get(True)
|
||||
|
||||
if mlir_api_version < 54 and arg_names:
|
||||
named_arg_attrs = arg_attrs[num_dim_vars + num_tokens:]
|
||||
for attrs, name_ in zip(named_arg_attrs, arg_names):
|
||||
if name_:
|
||||
attrs['jax.arg_info'] = ir.StringAttr.get(name_)
|
||||
|
||||
func_op.arg_attrs = ir.ArrayAttr.get(
|
||||
[ir.DictAttr.get(attrs) for attrs in arg_attrs])
|
||||
|
||||
@ -1032,7 +1021,7 @@ def lower_jaxpr_to_fun(
|
||||
func_op.result_attrs = ir.ArrayAttr.get(
|
||||
[ir.DictAttr.get(attrs) for attrs in result_attrs])
|
||||
|
||||
if mlir_api_version >= 54 and arg_names:
|
||||
if arg_names:
|
||||
arg_locs = [ir.Location.unknown()] * (num_dim_vars + num_tokens)
|
||||
for n in arg_names:
|
||||
arg_locs.append(ir.Location.name(n) if n else ir.Location.unknown())
|
||||
@ -1907,14 +1896,9 @@ def _emit_tpu_python_callback(
|
||||
callback.__name__, sharding=sharding)
|
||||
outputs.append(out)
|
||||
recv_channels.append(channel)
|
||||
if xla_extension_version < 161:
|
||||
opaque = backend.make_python_callback_from_host_send_and_recv(
|
||||
_wrapped_callback, operand_shapes, result_shapes, send_channels,
|
||||
recv_channels) # type: ignore # pylint: disable=missing-parameter
|
||||
else:
|
||||
opaque = backend.make_python_callback_from_host_send_and_recv(
|
||||
_wrapped_callback, operand_shapes, result_shapes, send_channels,
|
||||
recv_channels, pickle_util.dumps) # type: ignore # pylint: disable=missing-parameter
|
||||
opaque = backend.make_python_callback_from_host_send_and_recv(
|
||||
_wrapped_callback, operand_shapes, result_shapes, send_channels,
|
||||
recv_channels, pickle_util.dumps) # type: ignore # pylint: disable=missing-parameter
|
||||
ctx.module_context.add_host_callback(opaque)
|
||||
return outputs, token, opaque
|
||||
|
||||
@ -2183,18 +2167,9 @@ def refine_polymorphic_shapes(module: ir.Module) -> ir.Module:
|
||||
shape polymorphism, runs shape refinement to resolve all the dynamic shapes.
|
||||
Then verifies that there are no more dynamic shapes in the module.
|
||||
"""
|
||||
if xc.mlir_api_version >= 53:
|
||||
refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
|
||||
module_to_bytecode(module), enable_shape_assertions=True,
|
||||
validate_static_shapes=True)
|
||||
elif xc.mlir_api_version == 52:
|
||||
refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
|
||||
module_to_bytecode(module), enable_shape_assertions=True)
|
||||
elif xc.mlir_api_version >= 50:
|
||||
refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
|
||||
module_to_bytecode(module))
|
||||
else:
|
||||
raise NotImplementedError("refine_polymorphic_shapes needs jaxlib 0.4.12")
|
||||
refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
|
||||
module_to_bytecode(module), enable_shape_assertions=True,
|
||||
validate_static_shapes=True)
|
||||
|
||||
context = make_ir_context()
|
||||
with context:
|
||||
|
@ -57,7 +57,6 @@ from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.partition_spec import PartitionSpec
|
||||
@ -2730,12 +2729,8 @@ class MeshExecutable(stages.XlaExecutable):
|
||||
fastpath_data = None
|
||||
return outs, fastpath_data
|
||||
|
||||
if xla_extension_version >= 169:
|
||||
return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [],
|
||||
tree_util.default_registry) # type: ignore
|
||||
else:
|
||||
return xc._xla.pjit(
|
||||
self.unsafe_call.name, None, aot_cache_miss, [], [], []) # type: ignore
|
||||
return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [],
|
||||
tree_util.default_registry) # type: ignore
|
||||
|
||||
|
||||
def check_arg_avals_for_call(ref_avals, arg_avals,
|
||||
|
@ -342,35 +342,20 @@ def _approx_top_k_lowering(ctx, operand, *, k,
|
||||
if fallback:
|
||||
backend_config["is_fallback"] = mlir.ir.BoolAttr.get(fallback)
|
||||
|
||||
if xc.mlir_api_version >= 51: # jaxlib >= 0.4.14
|
||||
if all(core.is_constant_shape(aval_out.shape) for aval_out in ctx.avals_out):
|
||||
result_shapes = None
|
||||
else:
|
||||
result_shapes = [
|
||||
mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, aval_out.shape))
|
||||
for aval_out in ctx.avals_out]
|
||||
|
||||
out = mlir.custom_call(
|
||||
"ApproxTopK",
|
||||
[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
|
||||
[operand, iota, init_val, init_arg],
|
||||
called_computations=[comparator.name.value],
|
||||
backend_config=backend_config,
|
||||
result_shapes=result_shapes)
|
||||
if all(core.is_constant_shape(aval_out.shape) for aval_out in ctx.avals_out):
|
||||
result_shapes = None
|
||||
else:
|
||||
# Older versions do not support has_side_effect attribute; we just use
|
||||
# the old lowering code.
|
||||
if any(not core.is_constant_shape(aval_out.shape) for aval_out in ctx.avals_out):
|
||||
raise ValueError("approx_top_k not supported with shape polymorphism; "
|
||||
"try upgrading jaxlib")
|
||||
out = hlo.CustomCallOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
|
||||
[operand, iota, init_val, init_arg],
|
||||
call_target_name=b"ApproxTopK",
|
||||
called_computations=mlir.ir.ArrayAttr.get(
|
||||
[mlir.ir.FlatSymbolRefAttr.get(comparator.name.value)]))
|
||||
backend_config_attr = mlir.ir.DictAttr.get(backend_config,
|
||||
ctx.module_context.context)
|
||||
out.operation.attributes["mhlo.backend_config"] = backend_config_attr
|
||||
result_shapes = [
|
||||
mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, aval_out.shape))
|
||||
for aval_out in ctx.avals_out]
|
||||
|
||||
out = mlir.custom_call(
|
||||
"ApproxTopK",
|
||||
[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
|
||||
[operand, iota, init_val, init_arg],
|
||||
called_computations=[comparator.name.value],
|
||||
backend_config=backend_config,
|
||||
result_shapes=result_shapes)
|
||||
|
||||
return out.results
|
||||
|
||||
|
@ -4194,9 +4194,6 @@ top_k_p.def_abstract_eval(_top_k_abstract_eval)
|
||||
def _top_k_lower(ctx, operand, k):
|
||||
if core.is_constant_dim(k):
|
||||
return chlo.TopKOp(operand, mlir.i64_attr(k)).results
|
||||
if xla_client.mlir_api_version < 54:
|
||||
# TODO: https://github.com/openxla/stablehlo/issues/1396
|
||||
raise ValueError("native serialization with shape polymorphism not implemented for top_k")
|
||||
k_value, = mlir.eval_dynamic_shape_as_vals(ctx, (k,))
|
||||
out_values_aval, out_indices_aval, = ctx.avals_out
|
||||
return mlir.custom_call(
|
||||
|
@ -54,7 +54,6 @@ from jax._src.interpreters import pxla
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import func as func_dialect
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.sharding_impls import (
|
||||
NamedSharding, XLACompatibleSharding, GSPMDSharding,
|
||||
XLADeviceAssignment, SingleDeviceSharding, PmapSharding,
|
||||
@ -256,17 +255,11 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
|
||||
fastpath_data = _get_fastpath_data(executable, out_tree, args_flat, out_flat)
|
||||
return outs, fastpath_data
|
||||
|
||||
if xla_extension_version >= 169:
|
||||
cpp_pjit_f = xc._xla.pjit( # type: ignore
|
||||
getattr(fun, "__name__", "<unnamed function>"), # type: ignore
|
||||
fun, cache_miss, static_argnums, static_argnames, # type: ignore
|
||||
donate_argnums, tree_util.default_registry, # type: ignore
|
||||
_get_cpp_global_cache(pjit_has_explicit_sharding)) # type: ignore
|
||||
else:
|
||||
cpp_pjit_f = xc._xla.pjit( # type: ignore
|
||||
getattr(fun, "__name__", "<unnamed function>"), # type: ignore
|
||||
fun, cache_miss, static_argnums, static_argnames, # type: ignore
|
||||
donate_argnums, _get_cpp_global_cache(pjit_has_explicit_sharding)) # type: ignore
|
||||
cpp_pjit_f = xc._xla.pjit( # type: ignore
|
||||
getattr(fun, "__name__", "<unnamed function>"), # type: ignore
|
||||
fun, cache_miss, static_argnums, static_argnames, # type: ignore
|
||||
donate_argnums, tree_util.default_registry, # type: ignore
|
||||
_get_cpp_global_cache(pjit_has_explicit_sharding)) # type: ignore
|
||||
|
||||
|
||||
cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
|
||||
@ -1207,13 +1200,9 @@ def _pjit_call_impl(*args, jaxpr,
|
||||
donated_argnums = [i for i, d in enumerate(donated_invars) if d]
|
||||
has_explicit_sharding = _pjit_explicit_sharding(
|
||||
in_shardings, out_shardings, None, None)
|
||||
if xla_extension_version >= 169:
|
||||
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
|
||||
tree_util.default_registry,
|
||||
_get_cpp_global_cache(has_explicit_sharding))(*args)
|
||||
else:
|
||||
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums, # type: ignore
|
||||
_get_cpp_global_cache(has_explicit_sharding))(*args)
|
||||
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
|
||||
tree_util.default_registry,
|
||||
_get_cpp_global_cache(has_explicit_sharding))(*args)
|
||||
|
||||
pjit_p.def_impl(_pjit_call_impl)
|
||||
|
||||
|
@ -36,7 +36,6 @@ from jax._src import util
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.partition_spec import PartitionSpec
|
||||
|
||||
import numpy as np
|
||||
@ -108,15 +107,10 @@ class XLACompatibleSharding(sharding.Sharding):
|
||||
def is_equivalent_to(self: XLACompatibleSharding, # type: ignore
|
||||
other: XLACompatibleSharding, ndim: int) -> bool:
|
||||
try:
|
||||
if xla_extension_version >= 168:
|
||||
return (are_op_shardings_equal(self._to_xla_hlo_sharding(ndim),
|
||||
other._to_xla_hlo_sharding(ndim))
|
||||
and self._device_assignment == other._device_assignment and
|
||||
self.memory_kind == other.memory_kind)
|
||||
else:
|
||||
return (are_op_shardings_equal(self._to_xla_hlo_sharding(ndim),
|
||||
other._to_xla_hlo_sharding(ndim))
|
||||
and self._device_assignment == other._device_assignment)
|
||||
return (are_op_shardings_equal(self._to_xla_hlo_sharding(ndim),
|
||||
other._to_xla_hlo_sharding(ndim))
|
||||
and self._device_assignment == other._device_assignment and
|
||||
self.memory_kind == other.memory_kind)
|
||||
# NotImplementedError is raised by PmapSharding because it can't lower
|
||||
# to OpSharding. So if `other` is a PmapSharding, default to a strict
|
||||
# equality check.
|
||||
@ -213,17 +207,14 @@ class NamedSharding(XLACompatibleSharding):
|
||||
self._preprocess()
|
||||
|
||||
def __reduce__(self):
|
||||
if xla_extension_version >= 168:
|
||||
return (
|
||||
type(self),
|
||||
(self.mesh, self.spec),
|
||||
{'memory_kind': self.memory_kind},
|
||||
)
|
||||
else:
|
||||
return type(self), (self.mesh, self.spec)
|
||||
return (
|
||||
type(self),
|
||||
(self.mesh, self.spec),
|
||||
{'memory_kind': self.memory_kind},
|
||||
)
|
||||
|
||||
def _preprocess(self):
|
||||
if xla_extension_version >= 170 and self.memory_kind is not None:
|
||||
if self.memory_kind is not None:
|
||||
# Will error if memory_kind does not exist on the device.
|
||||
self.mesh.devices.flat[0].memory(self.memory_kind)
|
||||
|
||||
@ -242,18 +233,12 @@ class NamedSharding(XLACompatibleSharding):
|
||||
_check_mesh_resource_axis(self.mesh, self._parsed_pspec)
|
||||
|
||||
def __repr__(self):
|
||||
if xla_extension_version >= 168:
|
||||
mem = '' if self.memory_kind is None else f', memory_kind={self.memory_kind}'
|
||||
return f'NamedSharding(mesh={dict(self.mesh.shape)}, spec={self.spec}{mem})'
|
||||
else:
|
||||
return f'NamedSharding(mesh={dict(self.mesh.shape)}, spec={self.spec})'
|
||||
mem = '' if self.memory_kind is None else f', memory_kind={self.memory_kind}'
|
||||
return f'NamedSharding(mesh={dict(self.mesh.shape)}, spec={self.spec}{mem})'
|
||||
|
||||
def __hash__(self):
|
||||
if not hasattr(self, '_hash'):
|
||||
if xla_extension_version >= 168:
|
||||
self._hash = hash((self.mesh, self.memory_kind, self._parsed_pspec))
|
||||
else:
|
||||
self._hash = hash((self.mesh, self._parsed_pspec))
|
||||
self._hash = hash((self.mesh, self.memory_kind, self._parsed_pspec))
|
||||
return self._hash
|
||||
|
||||
def __eq__(self, other):
|
||||
@ -262,16 +247,11 @@ class NamedSharding(XLACompatibleSharding):
|
||||
if id(self) == id(other):
|
||||
return True
|
||||
parsed_pspec_equal = self._parsed_pspec == other._parsed_pspec
|
||||
if xla_extension_version >= 168:
|
||||
if (id(self.mesh) == id(other.mesh) and
|
||||
self.memory_kind == other.memory_kind and parsed_pspec_equal):
|
||||
return True
|
||||
return (self.mesh == other.mesh and self.memory_kind == other.memory_kind
|
||||
and parsed_pspec_equal)
|
||||
else:
|
||||
if id(self.mesh) == id(other.mesh) and parsed_pspec_equal:
|
||||
return True
|
||||
return self.mesh == other.mesh and parsed_pspec_equal
|
||||
if (id(self.mesh) == id(other.mesh) and
|
||||
self.memory_kind == other.memory_kind and parsed_pspec_equal):
|
||||
return True
|
||||
return (self.mesh == other.mesh and self.memory_kind == other.memory_kind
|
||||
and parsed_pspec_equal)
|
||||
|
||||
def is_compatible_aval(self, aval_shape: Shape):
|
||||
assert self._parsed_pspec is not None
|
||||
@ -285,12 +265,8 @@ class NamedSharding(XLACompatibleSharding):
|
||||
|
||||
@classmethod
|
||||
def _from_parsed_pspec(cls, mesh, parsed_pspec, *, memory_kind=None):
|
||||
if xla_extension_version >= 168:
|
||||
return cls(mesh, parsed_pspec.get_partition_spec(),
|
||||
memory_kind=memory_kind, _parsed_pspec=parsed_pspec)
|
||||
else:
|
||||
return cls(mesh, parsed_pspec.get_partition_spec(),
|
||||
_parsed_pspec=parsed_pspec)
|
||||
return cls(mesh, parsed_pspec.get_partition_spec(),
|
||||
memory_kind=memory_kind, _parsed_pspec=parsed_pspec)
|
||||
|
||||
@property
|
||||
def device_set(self) -> set[Device]:
|
||||
@ -377,24 +353,15 @@ class SingleDeviceSharding(XLACompatibleSharding):
|
||||
self._memory_kind = memory_kind
|
||||
|
||||
def __reduce__(self):
|
||||
if xla_extension_version >= 168:
|
||||
return type(self), (self._device,), {'memory_kind': self._memory_kind}
|
||||
else:
|
||||
return type(self), (self._device,)
|
||||
return type(self), (self._device,), {'memory_kind': self._memory_kind}
|
||||
|
||||
def __repr__(self):
|
||||
if xla_extension_version >= 168:
|
||||
mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}'
|
||||
return f"SingleDeviceSharding(device={repr(self._device)}{mem})"
|
||||
else:
|
||||
return f"SingleDeviceSharding(device={repr(self._device)})"
|
||||
mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}'
|
||||
return f"SingleDeviceSharding(device={repr(self._device)}{mem})"
|
||||
|
||||
def __hash__(self):
|
||||
if not hasattr(self, '_hash'):
|
||||
if xla_extension_version >= 168:
|
||||
self._hash = hash((self._device, self._memory_kind))
|
||||
else:
|
||||
self._hash = hash(self._device)
|
||||
self._hash = hash((self._device, self._memory_kind))
|
||||
return self._hash
|
||||
|
||||
def __eq__(self, other):
|
||||
@ -402,11 +369,8 @@ class SingleDeviceSharding(XLACompatibleSharding):
|
||||
return False
|
||||
if id(self) == id(other):
|
||||
return True
|
||||
if xla_extension_version >= 168:
|
||||
return (self._device == other._device and
|
||||
self._memory_kind == other._memory_kind)
|
||||
else:
|
||||
return self._device == other._device
|
||||
return (self._device == other._device and
|
||||
self._memory_kind == other._memory_kind)
|
||||
|
||||
@property
|
||||
def device_set(self) -> set[Device]:
|
||||
@ -414,10 +378,7 @@ class SingleDeviceSharding(XLACompatibleSharding):
|
||||
|
||||
@property
|
||||
def memory_kind(self) -> str | None:
|
||||
if xla_extension_version >= 168:
|
||||
return self._memory_kind
|
||||
else:
|
||||
return None
|
||||
return self._memory_kind
|
||||
|
||||
def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]: # type: ignore
|
||||
return {self._device: (slice(None),) * len(global_shape)}
|
||||
@ -671,10 +632,7 @@ class PositionalSharding(XLACompatibleSharding):
|
||||
|
||||
@property
|
||||
def memory_kind(self) -> str | None:
|
||||
if xla_extension_version >= 168:
|
||||
return self._memory_kind
|
||||
else:
|
||||
return None
|
||||
return self._memory_kind
|
||||
|
||||
@functools.cached_property
|
||||
def is_fully_replicated(self) -> bool:
|
||||
@ -753,22 +711,12 @@ class GSPMDSharding(XLACompatibleSharding):
|
||||
self._hlo_sharding = op_sharding
|
||||
self._memory_kind = memory_kind
|
||||
|
||||
if xla_extension_version < 159:
|
||||
@property
|
||||
def _hlo_sharding(self): # type: ignore
|
||||
if isinstance(self._op_sharding, xc.OpSharding): # type: ignore
|
||||
return xc.HloSharding.from_proto(self._op_sharding) # type: ignore
|
||||
return self._op_sharding # type: ignore
|
||||
|
||||
def __reduce__(self):
|
||||
if xla_extension_version >= 168:
|
||||
return (
|
||||
type(self),
|
||||
(self._devices, self._hlo_sharding.to_proto()),
|
||||
{'memory_kind': self._memory_kind},
|
||||
)
|
||||
else:
|
||||
return type(self), (self._devices, self._hlo_sharding.to_proto())
|
||||
return (
|
||||
type(self),
|
||||
(self._devices, self._hlo_sharding.to_proto()),
|
||||
{'memory_kind': self._memory_kind},
|
||||
)
|
||||
|
||||
@functools.cached_property
|
||||
def _hlo_sharding_hash(self):
|
||||
@ -779,29 +727,19 @@ class GSPMDSharding(XLACompatibleSharding):
|
||||
return False
|
||||
if id(self) == id(other):
|
||||
return True
|
||||
if xla_extension_version >= 168:
|
||||
return (are_op_shardings_equal(self._hlo_sharding, other._hlo_sharding)
|
||||
and self._devices == other._devices and
|
||||
self._memory_kind == other._memory_kind)
|
||||
else:
|
||||
return (are_op_shardings_equal(self._hlo_sharding, other._hlo_sharding)
|
||||
and self._devices == other._devices)
|
||||
return (are_op_shardings_equal(self._hlo_sharding, other._hlo_sharding)
|
||||
and self._devices == other._devices and
|
||||
self._memory_kind == other._memory_kind)
|
||||
|
||||
def __hash__(self):
|
||||
if not hasattr(self, '_hash'):
|
||||
if xla_extension_version >= 168:
|
||||
self._hash = hash((self._devices, self._hlo_sharding_hash,
|
||||
self._memory_kind))
|
||||
else:
|
||||
self._hash = hash((self._devices, self._hlo_sharding_hash))
|
||||
self._hash = hash((self._devices, self._hlo_sharding_hash,
|
||||
self._memory_kind))
|
||||
return self._hash
|
||||
|
||||
def __repr__(self):
|
||||
if xla_extension_version >= 168:
|
||||
mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}'
|
||||
return f'GSPMDSharding({repr(self._hlo_sharding)}{mem})'
|
||||
else:
|
||||
return f'GSPMDSharding({repr(self._hlo_sharding)})'
|
||||
mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}'
|
||||
return f'GSPMDSharding({repr(self._hlo_sharding)}{mem})'
|
||||
|
||||
def is_compatible_aval(self, aval_shape: Shape):
|
||||
num_ways_dim_sharded, _ = get_num_ways_dim_sharded(self._hlo_sharding)
|
||||
@ -817,10 +755,7 @@ class GSPMDSharding(XLACompatibleSharding):
|
||||
|
||||
@property
|
||||
def memory_kind(self) -> str | None:
|
||||
if xla_extension_version >= 168:
|
||||
return self._memory_kind
|
||||
else:
|
||||
return None
|
||||
return self._memory_kind
|
||||
|
||||
@functools.lru_cache(maxsize=4096)
|
||||
def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
|
||||
@ -842,11 +777,8 @@ class GSPMDSharding(XLACompatibleSharding):
|
||||
|
||||
@classmethod
|
||||
def get_replicated(cls, device_assignment, *, memory_kind: str | None = None):
|
||||
if xla_extension_version >= 168:
|
||||
return cls(tuple(device_assignment), get_replicated_hlo_sharding(),
|
||||
memory_kind=memory_kind)
|
||||
else:
|
||||
return cls(tuple(device_assignment), get_replicated_hlo_sharding())
|
||||
return cls(tuple(device_assignment), get_replicated_hlo_sharding(),
|
||||
memory_kind=memory_kind)
|
||||
|
||||
|
||||
class AUTO:
|
||||
|
@ -26,7 +26,6 @@ import warnings
|
||||
|
||||
from jax._src import traceback_util
|
||||
from jax._src.lib import pytree
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.util import safe_zip
|
||||
from jax._src.util import unzip2
|
||||
|
||||
@ -41,14 +40,11 @@ PyTreeDef = pytree.PyTreeDef
|
||||
|
||||
# TODO(phawkins): make this unconditional when jaxlib 0.4.14 is the minimum.
|
||||
default_registry: pytree.PyTreeRegistry | None
|
||||
if xla_extension_version >= 169:
|
||||
default_registry = pytree.default_registry()
|
||||
# Set __module__ and __name__, which allow this registry to be pickled by
|
||||
# reference.
|
||||
default_registry.__module__ = __name__
|
||||
default_registry.__name__ = "default_registry"
|
||||
else:
|
||||
default_registry = None
|
||||
default_registry = pytree.default_registry()
|
||||
# Set __module__ and __name__, which allow this registry to be pickled by
|
||||
# reference.
|
||||
default_registry.__module__ = __name__
|
||||
default_registry.__name__ = "default_registry"
|
||||
|
||||
def tree_flatten(tree: Any,
|
||||
is_leaf: Callable[[Any], bool] | None = None
|
||||
|
@ -40,7 +40,6 @@ from jax._src import distributed
|
||||
from jax._src import config as jax_config
|
||||
from jax._src.config import bool_env, config, int_env
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
|
||||
@ -132,7 +131,7 @@ def get_compile_options(
|
||||
build_options = compile_options.executable_build_options
|
||||
build_options.use_spmd_partitioning = use_spmd_partitioning
|
||||
build_options.use_auto_spmd_partitioning = use_auto_spmd_partitioning
|
||||
if xla_extension_version > 165 and fdo_profile is not None:
|
||||
if fdo_profile is not None:
|
||||
build_options.fdo_profile = fdo_profile
|
||||
if use_auto_spmd_partitioning:
|
||||
build_options.auto_spmd_partitioning_mesh_shape = auto_spmd_partitioning_mesh_shape
|
||||
@ -270,23 +269,13 @@ def make_gpu_client(
|
||||
if visible_devices != "all":
|
||||
allowed_devices = {int(x) for x in visible_devices.split(",")}
|
||||
|
||||
if xla_extension_version < 160:
|
||||
return xla_client.make_gpu_client(
|
||||
distributed_client=distributed.global_state.client,
|
||||
node_id=distributed.global_state.process_id,
|
||||
platform_name=platform_name,
|
||||
allowed_devices=allowed_devices,
|
||||
)
|
||||
else:
|
||||
# Remove `type: ignore` when the min jaxlib version (xla_extension_version)
|
||||
# >= 160.
|
||||
return xla_client.make_gpu_client(
|
||||
distributed_client=distributed.global_state.client,
|
||||
node_id=distributed.global_state.process_id,
|
||||
num_nodes=distributed.global_state.num_processes,
|
||||
platform_name=platform_name,
|
||||
allowed_devices=allowed_devices,
|
||||
) # type: ignore
|
||||
return xla_client.make_gpu_client(
|
||||
distributed_client=distributed.global_state.client,
|
||||
node_id=distributed.global_state.process_id,
|
||||
num_nodes=distributed.global_state.num_processes,
|
||||
platform_name=platform_name,
|
||||
allowed_devices=allowed_devices,
|
||||
)
|
||||
|
||||
|
||||
if hasattr(xla_client, "make_gpu_client"):
|
||||
@ -470,20 +459,17 @@ def register_plugin(
|
||||
)
|
||||
xla_client.load_pjrt_plugin_dynamically(plugin_name, library_path)
|
||||
|
||||
if xla_extension_version < 165:
|
||||
return xla_client.make_c_api_client(plugin_name, options)
|
||||
else:
|
||||
if distributed.global_state.client is None:
|
||||
return xla_client.make_c_api_client(plugin_name, options, None)
|
||||
distribute_options = {
|
||||
'node_id': distributed.global_state.process_id,
|
||||
'num_nodes': distributed.global_state.num_processes,
|
||||
}
|
||||
if options is not None:
|
||||
distribute_options.update(options)
|
||||
return xla_client.make_c_api_client(
|
||||
plugin_name, distribute_options, distributed.global_state.client
|
||||
)
|
||||
if distributed.global_state.client is None:
|
||||
return xla_client.make_c_api_client(plugin_name, options, None)
|
||||
distribute_options = {
|
||||
'node_id': distributed.global_state.process_id,
|
||||
'num_nodes': distributed.global_state.num_processes,
|
||||
}
|
||||
if options is not None:
|
||||
distribute_options.update(options)
|
||||
return xla_client.make_c_api_client(
|
||||
plugin_name, distribute_options, distributed.global_state.client
|
||||
)
|
||||
|
||||
|
||||
logger.debug(
|
||||
|
@ -842,11 +842,6 @@ def _export_native_vjp(primal_fun_jax, primal: Exported) -> Exported:
|
||||
|
||||
def call_exported(exported: Exported) -> Callable[..., jax.Array]:
|
||||
|
||||
if (exported.serialization_version >= 7 and
|
||||
exported.uses_shape_polymorphism):
|
||||
if xla_client.mlir_api_version < 52:
|
||||
raise NotImplementedError(
|
||||
"Current jaxlib does not support shape polymorphism with serialization version >= 7")
|
||||
@jax.custom_vjp
|
||||
def f_flat(*args_flat):
|
||||
return call_exported_p.bind(*args_flat, exported=exported)
|
||||
|
@ -298,9 +298,6 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
ValueError,
|
||||
f"The requested jax_serialization version {v} is outside the range of supported versions"))
|
||||
|
||||
if (xc.mlir_api_version <= 51 and
|
||||
config.jax_serialization_version >= 7):
|
||||
raise unittest.SkipTest("Not supported in old jaxlib")
|
||||
exp = jax_export.export(jnp.sin)(
|
||||
jax_export.poly_spec((3, 4), np.float32, "w, h"))
|
||||
# Peek at the module
|
||||
@ -347,15 +344,10 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
arg_shape=(3, 4, 12), arg_dtype=np.float32,
|
||||
expect_error=None): # If given, error from running the exported module
|
||||
|
||||
if xc.mlir_api_version <= 51:
|
||||
raise unittest.SkipTest("Not supported in old jaxlib")
|
||||
def f(x): # x: f32[poly_spec]
|
||||
return jnp.reshape(x, (-1, x.shape[1]))
|
||||
|
||||
if xc.mlir_api_version <= 51:
|
||||
disabled_checks = (jax_export.DisabledSafetyCheck.shape_assertions(),)
|
||||
else:
|
||||
disabled_checks = ()
|
||||
disabled_checks = ()
|
||||
exp_f = jax_export.export(f, disabled_checks=disabled_checks)(
|
||||
jax_export.poly_spec((3, 4, 12), np.float32, poly_spec))
|
||||
self.assertEqual(exp_f.uses_shape_polymorphism, poly_spec != "3,4,12")
|
||||
@ -453,8 +445,6 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
expect_error_outer_exp=None,
|
||||
expect_error_run=None):
|
||||
# Polymorphic export called with static or polymorphic shapes
|
||||
if xc.mlir_api_version <= 51:
|
||||
raise unittest.SkipTest("Not supported in old jaxlib")
|
||||
def inner(x): # x: inner_poly_spec
|
||||
return jnp.reshape(x, (-1, x.shape[1]))
|
||||
|
||||
|
@ -16,7 +16,7 @@
|
||||
# eval()-ed by setup.py, so it should not have any dependencies.
|
||||
|
||||
__version__ = "0.4.15"
|
||||
_minimum_jaxlib_version = "0.4.11"
|
||||
_minimum_jaxlib_version = "0.4.14"
|
||||
|
||||
def _version_as_tuple(version_str):
|
||||
return tuple(int(i) for i in version_str.split(".") if i.isdigit())
|
||||
|
@ -60,7 +60,6 @@ from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src.lib import xla_extension_version
|
||||
import jax._src.util as jax_util
|
||||
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
|
||||
import jax.custom_batching
|
||||
@ -1209,12 +1208,8 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
def test_jit_lower_compile_cost_analysis(self):
|
||||
f = self.jit(lambda x: x).lower(1.).compile()
|
||||
g = self.jit(lambda x: x + 4).lower(1.).compile()
|
||||
if xla_extension_version >= 164:
|
||||
self.assertIsNotNone(f.cost_analysis())
|
||||
self.assertIsNotNone(g.cost_analysis())
|
||||
else:
|
||||
f.cost_analysis() # doesn't raise
|
||||
g.cost_analysis() # doesn't raise
|
||||
self.assertIsNotNone(f.cost_analysis())
|
||||
self.assertIsNotNone(g.cost_analysis())
|
||||
|
||||
@jtu.skip_on_xla_cpu_mlir
|
||||
def test_jit_lower_compile_memory_analysis(self):
|
||||
@ -1283,8 +1278,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertIn("jax.result_info = \"['a']\"", mhlo_str)
|
||||
self.assertIn("jax.result_info = \"['b'][0][0]\"", mhlo_str)
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 171,
|
||||
'Test requires xla_extension_version >= 171')
|
||||
def test_jit_lower_compile_with_compiler_options(self):
|
||||
def f(x):
|
||||
return jnp.sqrt(x ** 2) + 1.
|
||||
|
@ -21,7 +21,6 @@ from jax import config
|
||||
import jax.dlpack
|
||||
import jax.numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_extension_version
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -164,7 +163,6 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(x_np, x_jax)
|
||||
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 163, "Test requires jaxlib 0.4.13")
|
||||
class CudaArrayInterfaceTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -28,7 +28,6 @@ from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.config import compilation_cache_include_metadata_in_key
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension_version
|
||||
import numpy as np
|
||||
|
||||
|
||||
@ -271,8 +270,7 @@ class CacheKeyTest(jtu.JaxTestCase):
|
||||
compile_options.executable_build_options.device_assignment = (
|
||||
device_assignment
|
||||
)
|
||||
if xla_extension_version > 165:
|
||||
compile_options.executable_build_options.fdo_profile = b"test_profile"
|
||||
compile_options.executable_build_options.fdo_profile = b"test_profile"
|
||||
return compile_options
|
||||
|
||||
def get_hashed_value(self, hash_function, hash_function_input):
|
||||
|
@ -13,15 +13,12 @@
|
||||
# limitations under the License.
|
||||
"""Tests for release_backend_clients."""
|
||||
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
import jax
|
||||
from jax import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.lib import xla_extension_version
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
@ -29,9 +26,6 @@ config.parse_flags_with_absl()
|
||||
class ClearBackendsTest(jtu.JaxTestCase):
|
||||
|
||||
def test_clear_backends(self):
|
||||
if xla_extension_version < 164 and xb.using_pjrt_c_api():
|
||||
raise unittest.SkipTest('test crashes runtime with PJRT C API')
|
||||
|
||||
g = jax.jit(lambda x, y: x * y)
|
||||
self.assertEqual(g(1, 2), 2)
|
||||
self.assertNotEmpty(xb.get_backend().live_executables())
|
||||
|
@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
import unittest
|
||||
import glob
|
||||
import logging
|
||||
import math
|
||||
@ -26,7 +25,6 @@ from jax import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax.sharding import NamedSharding
|
||||
from jax.experimental import profiler as exp_profiler
|
||||
from jax._src.lib import xla_extension_version
|
||||
import jax.numpy as jnp
|
||||
from jax.sharding import PartitionSpec as P
|
||||
import numpy as np
|
||||
@ -37,8 +35,6 @@ config.parse_flags_with_absl()
|
||||
@jtu.pytest_mark_if_available('multiaccelerator')
|
||||
class PgleTest(jtu.JaxTestCase):
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 169,
|
||||
'Requires xla_extension_version >= 169')
|
||||
def testPassingFDOProfile(self):
|
||||
mesh = jtu.create_global_mesh((2,), ('x',))
|
||||
@partial(
|
||||
|
@ -22,7 +22,6 @@ from absl.testing import absltest
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax.interpreters import xla
|
||||
|
||||
from jax._src.config import config
|
||||
@ -51,13 +50,12 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
expected_device_assignment)
|
||||
|
||||
def test_set_fdo_profile(self):
|
||||
if xla_extension_version > 166:
|
||||
compile_options = xb.get_compile_options(
|
||||
num_replicas=1, num_partitions=1, fdo_profile=b"test_profile"
|
||||
)
|
||||
self.assertEqual(
|
||||
compile_options.executable_build_options.fdo_profile, "test_profile"
|
||||
)
|
||||
compile_options = xb.get_compile_options(
|
||||
num_replicas=1, num_partitions=1, fdo_profile=b"test_profile"
|
||||
)
|
||||
self.assertEqual(
|
||||
compile_options.executable_build_options.fdo_profile, "test_profile"
|
||||
)
|
||||
|
||||
def test_parameter_replication_default(self):
|
||||
c = xc.XlaBuilder("test")
|
||||
@ -125,10 +123,7 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
self.assertEqual(registration.priority, 400)
|
||||
self.assertTrue(registration.experimental)
|
||||
mock_plugin_loaded.assert_called_once_with("name1")
|
||||
if xla_extension_version < 165:
|
||||
mock_make.assert_called_once_with("name1", None)
|
||||
else:
|
||||
mock_make.assert_called_once_with("name1", None, None)
|
||||
mock_make.assert_called_once_with("name1", None, None)
|
||||
|
||||
def test_register_plugin_with_config(self):
|
||||
test_json_file_path = os.path.join(
|
||||
@ -149,27 +144,16 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
self.assertEqual(registration.priority, 400)
|
||||
self.assertTrue(registration.experimental)
|
||||
mock_plugin_loaded.assert_called_once_with("name1")
|
||||
if xla_extension_version < 165:
|
||||
mock_make.assert_called_once_with(
|
||||
"name1",
|
||||
{
|
||||
"int_option": 64,
|
||||
"int_list_option": [32, 64],
|
||||
"string_option": "string",
|
||||
"float_option": 1.0,
|
||||
},
|
||||
)
|
||||
else:
|
||||
mock_make.assert_called_once_with(
|
||||
"name1",
|
||||
{
|
||||
"int_option": 64,
|
||||
"int_list_option": [32, 64],
|
||||
"string_option": "string",
|
||||
"float_option": 1.0,
|
||||
},
|
||||
None,
|
||||
)
|
||||
mock_make.assert_called_once_with(
|
||||
"name1",
|
||||
{
|
||||
"int_option": 64,
|
||||
"int_list_option": [32, 64],
|
||||
"string_option": "string",
|
||||
"float_option": 1.0,
|
||||
},
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
class GetBackendTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user