Remove more checks now that the minimum jaxlib version corresponds to xla_extension_version == 109. Also remove usage of xc._version and replace it with xla_extension_version.

PiperOrigin-RevId: 496474494
Yash Katariya 2022-12-19 13:13:15 -08:00 committed by jax authors
parent 4301a85d46
commit dbc39449b7
5 changed files with 19 additions and 42 deletions
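
For readers outside the JAX codebase: every deletion below is a jaxlib feature gate that became dead code once the minimum supported jaxlib reached xla_extension_version 109. A minimal sketch of the shape of such a gate, using the same imports the diffs below touch; the guarded call mirrors the live_arrays() hunk in the first file and is illustrative, not new API:

# Minimal sketch of a jaxlib feature gate (the pattern this commit removes
# wherever the version threshold is <= 109). Imports are the ones used in the
# diffs below; the guarded call mirrors the live_arrays() hunk.
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_extension_version

def live_arrays_gated(platform=None):
  if xla_extension_version >= 102:  # feature exists only in newer jaxlib builds
    return xb.get_backend(platform).live_arrays()
  raise RuntimeError("live_arrays() requires a newer jaxlib package.")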

View File

@@ -511,7 +511,6 @@ def _cpp_jit_clear_cache(self):
 def _jax_array_use_fast_path(execute, out_pytree_def, args_flat, out_flat):
   use_fastpath = (
-      xc._version >= 96 and
       # This is if we have already executed this code-path (most-recent entry
       # has been reset to None). Thus, we do not support the fast-path.
       execute is not None and
@@ -2286,8 +2285,7 @@ def _cpp_pmap(
         # No tracers in the outputs. Checking for ShardedDeviceArray should be
         # sufficient, but we use the more general `DeviceArray`.
         all(
-            isinstance(x, device_array.DeviceArray) or
-            xc._version >= 96 and isinstance(x, xc.ArrayImpl)
+            isinstance(x, device_array.DeviceArray) or isinstance(x, xc.ArrayImpl)
             for x in out_flat))
 
     ### If we can use the fastpath, we return required info to the caller.
@@ -3453,8 +3451,4 @@ def live_arrays(platform=None):
   If platform is None, it is the default backend.
   """
-  if xc._version >= 102:
-    return xb.get_backend(platform).live_arrays()
-  raise RuntimeError(
-      "live_arrays() is not supported yet. Please update your jaxlib package.")
+  return xb.get_backend(platform).live_arrays()
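
A hedged usage sketch of the now-unconditional live_arrays() API shown above; this assumes jax.live_arrays is re-exported at the top level, which is how the function is surfaced in releases of this era:

# Usage sketch for live_arrays(); assumes the top-level jax.live_arrays export.
import jax
import jax.numpy as jnp

x = jnp.arange(4)           # keeps one array alive on the default backend
live = jax.live_arrays()    # platform=None means the default backend
print(len(live) >= 1)       # True: x is still referenced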

View File

@@ -99,7 +99,7 @@ def _single_device_array_from_buf(buf, committed):
                    committed=committed, _skip_checks=True)
 
-@use_cpp_class(xc.ArrayImpl if xc._version >= 99 else None)
+@use_cpp_class(xc.ArrayImpl)
 class ArrayImpl(basearray.Array):
   # TODO(yashkatariya): Add __slots__ here.
@@ -612,9 +612,8 @@ core.pytype_aval_mappings[ArrayImpl] = abstract_arrays.canonical_concrete_aval
 xla.pytype_aval_mappings[ArrayImpl] = op.attrgetter('aval')
 xla.canonicalize_dtype_handlers[ArrayImpl] = pxla.identity
 api_util._shaped_abstractify_handlers[ArrayImpl] = op.attrgetter('aval')
-if xc._version >= 96:
-  # TODO(jakevdp) replace this with true inheritance at the C++ level.
-  basearray.Array.register(ArrayImpl)
+# TODO(jakevdp) replace this with true inheritance at the C++ level.
+basearray.Array.register(ArrayImpl)
 
 def _array_mlir_constant_handler(val, canonicalize_types=True):
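
The basearray.Array.register(ArrayImpl) call above uses standard abc virtual-subclass registration, which is why the TODO talks about replacing it with true inheritance at the C++ level. A standalone illustration; the class names are stand-ins, not JAX internals:

# Standalone illustration of abc registration; ArrayBase and CppArrayImpl are
# stand-ins, not JAX classes.
import abc

class ArrayBase(abc.ABC):       # plays the role of basearray.Array
  pass

class CppArrayImpl:             # plays the role of the C++-defined ArrayImpl
  pass

ArrayBase.register(CppArrayImpl)              # virtual subclass, no inheritance
print(isinstance(CppArrayImpl(), ArrayBase))  # True
print(issubclass(CppArrayImpl, ArrayBase))    # True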

View File

@@ -52,6 +52,7 @@ from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import func as func_dialect
 from jax._src.lib import xla_bridge as xb
 from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension_version
 from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
                            treedef_is_leaf, tree_structure, treedef_tuple)
 from jax._src.tree_util import prefix_errors
@@ -161,12 +162,9 @@ def _cpp_pjit(fun: Callable, infer_params, static_argnums):
     return outs, fastpath_data
 
-  if xc._version < 108:
-    cpp_pjit_f = xc._xla.pjit(fun, cache_miss, static_argnums)  # type: ignore
-  else:
-    cpp_pjit_f = xc._xla.pjit(  # type: ignore
-        getattr(fun, "__name__", "<unnamed function>"),  # type:ignore
-        cache_miss, static_argnums)
+  cpp_pjit_f = xc._xla.pjit(  # type: ignore
+      getattr(fun, "__name__", "<unnamed function>"),  # type:ignore
+      cache_miss, static_argnums)
 
   return wraps(fun)(cpp_pjit_f)
@@ -479,7 +477,7 @@ def pjit(
     return (args_flat, local_in_avals, params, in_tree, out_tree(),
             donate_argnums)
 
-  if FLAGS.experimental_cpp_pjit and xc._version >= 111:
+  if FLAGS.experimental_cpp_pjit and xla_extension_version >= 111:
     wrapped = _cpp_pjit(fun, infer_params, static_argnums)
   else:
     wrapped = _python_pjit(fun, infer_params)
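
The surviving _cpp_pjit branch above passes a function name string to the C++ binding and falls back to a placeholder when the callable has no __name__. A standalone illustration of that fallback; nothing here is JAX-specific:

# Why the getattr fallback exists: some callables, e.g. functools.partial
# objects, have no __name__ attribute.
from functools import partial

def f(x, y):
  return x + y

g = partial(f, 1)
print(getattr(f, "__name__", "<unnamed function>"))  # "f"
print(getattr(g, "__name__", "<unnamed function>"))  # "<unnamed function>"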

View File

@@ -33,7 +33,7 @@ Device = xc.Device
 Index = Tuple[slice, ...]
 XLADeviceAssignment = Sequence[Device]
 
-@use_cpp_class(xc.Sharding if xc._version >= 94 else None)
+@use_cpp_class(xc.Sharding)
 class Sharding(metaclass=abc.ABCMeta):
   """Abstract ``Sharding`` interface which describes how a ``jax.Array`` is laid out
   across devices.
@@ -99,7 +99,7 @@ class Sharding(metaclass=abc.ABCMeta):
 # Shardings that inherit from XLACompatibleSharding should implement the
 # `_device_assignment` property and `_to_xla_op_sharding` method.
 
-@use_cpp_class(xc.XLACompatibleSharding if xc._version >= 94 else None)
+@use_cpp_class(xc.XLACompatibleSharding)
 class XLACompatibleSharding(Sharding, metaclass=abc.ABCMeta):
   """A `Sharding` that describes shardings expressible to XLA.
@@ -190,15 +190,6 @@ def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]
   return out
 
-def _enable_cpp_named_sharding():
-  if xc._version >= 107:
-    return xc.NamedSharding
-  elif xc._version >= 95:
-    return xc.MeshPspecSharding  # type: ignore
-  else:
-    return None
-
 
 class _UnconstrainedPartitionSingleton:
 
   def __str__(self):
@@ -235,7 +226,7 @@ class PartitionSpec(tuple):
     return (PartitionSpec, tuple(self))
 
-@use_cpp_class(_enable_cpp_named_sharding())
+@use_cpp_class(xc.NamedSharding)
 class NamedSharding(XLACompatibleSharding):
   r"""NamedSharding is a way to express ``Sharding``\s using named axes.
@@ -367,7 +358,7 @@ def _get_replicated_op_sharding():
   return proto
 
-@use_cpp_class(xc.SingleDeviceSharding if xc._version >= 95 else None)
+@use_cpp_class(xc.SingleDeviceSharding)
 class SingleDeviceSharding(XLACompatibleSharding):
   """A subclass of ``XLACompatibleSharding`` that places its data on a single device.
@@ -415,7 +406,7 @@ class SingleDeviceSharding(XLACompatibleSharding):
     return _get_replicated_op_sharding()
 
-@use_cpp_class(xc.PmapSharding if xc._version >= 94 else None)
+@use_cpp_class(xc.PmapSharding)
 class PmapSharding(XLACompatibleSharding):
 
   @use_cpp_method
@@ -624,7 +615,7 @@ class DeviceIdSet:
             self._ids == other._ids)
 
-@use_cpp_class(xc.OpShardingSharding if xc._version >= 95 else None)
+@use_cpp_class(xc.OpShardingSharding)
 class OpShardingSharding(XLACompatibleSharding):
 
   @use_cpp_method
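
All of the sharding changes above are the same edit: use_cpp_class(...) is now handed the C++ class unconditionally instead of None on older jaxlibs. As a rough sketch of what a decorator like this can do (an illustration only, not the actual jax._src.util implementation, which is more careful about which attributes it copies):

# Illustrative sketch of a use_cpp_class-style decorator; NOT the real
# jax._src.util implementation.
def use_cpp_class_sketch(cpp_cls):
  def wrapper(py_cls):
    if cpp_cls is None:              # old behaviour: no C++ class in this jaxlib
      return py_cls
    for name, attr in py_cls.__dict__.items():
      if not name.startswith('__'):  # copy Python-defined methods onto it
        setattr(cpp_cls, name, attr)
    return cpp_cls                   # callers now see the C++-backed class
  return wrapper

class _FakeCppSharding:              # stand-in for a C++ class such as xc.Sharding
  pass

@use_cpp_class_sketch(_FakeCppSharding)
class Sharding:
  def devices(self):
    return ()

print(Sharding is _FakeCppSharding)  # True
print(Sharding().devices())          # () - the Python method was copied over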

View File

@@ -74,6 +74,7 @@ from jax._src.config import flags
 from jax._src.core import ConcreteArray, ShapedArray
 from jax._src.lib import xla_bridge as xb
 from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension_version
 from jax._src.lib import pmap_lib
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
@@ -3563,7 +3564,7 @@ class MeshExecutable(stages.XlaExecutable):
         not self.unsafe_call.has_host_callbacks):
       return None
 
-    if not flags.FLAGS.experimental_cpp_pjit or xc._version < 111:
+    if not flags.FLAGS.experimental_cpp_pjit or xla_extension_version < 111:
       return None
 
     def aot_cache_miss(*args, **kwargs):
@@ -3583,13 +3584,7 @@ class MeshExecutable(stages.XlaExecutable):
         fastpath_data = None
       return outs, fastpath_data
 
-    if xc._version < 108:
-      def dummy():
-        pass
-      dummy.__name__ = self.unsafe_call.name
-      return xc._xla.pjit(dummy, aot_cache_miss, [])  # type: ignore
-    else:
-      return xc._xla.pjit(self.unsafe_call.name, aot_cache_miss, [])  # type: ignore
+    return xc._xla.pjit(self.unsafe_call.name, aot_cache_miss, [])  # type: ignore
 
 
 def _out_shardings_for_trivial(
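
The final hunk drops a shim for jaxlib bindings older than xla_extension_version 108, which accepted a callable and read its __name__ rather than taking the name as a string. A standalone illustration of the two calling conventions; both register functions are hypothetical stand-ins, not jaxlib APIs:

# Hypothetical stand-ins for the old and new binding signatures; not jaxlib APIs.
def legacy_register(fn, cache_miss, static_argnums):
  return f"registered {fn.__name__}"       # old binding read the callable's name

def current_register(name, cache_miss, static_argnums):
  return f"registered {name}"              # new binding takes the name directly

def dummy():
  pass
dummy.__name__ = "pjit(my_computation)"    # the trick the removed shim used

print(legacy_register(dummy, cache_miss=None, static_argnums=()))
print(current_register("pjit(my_computation)", cache_miss=None, static_argnums=()))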