mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Remove more checks now that the minimum jaxlib version corresponds to xla_extension_version == 109. Also remove usage of xc._version
and replace it with xla_extension_version
.
PiperOrigin-RevId: 496474494
This commit is contained in:
parent
4301a85d46
commit
dbc39449b7
@ -511,7 +511,6 @@ def _cpp_jit_clear_cache(self):
|
||||
|
||||
def _jax_array_use_fast_path(execute, out_pytree_def, args_flat, out_flat):
|
||||
use_fastpath = (
|
||||
xc._version >= 96 and
|
||||
# This is if we have already executed this code-path (most-recent entry
|
||||
# has been reset to None). Thus, we do not support the fast-path.
|
||||
execute is not None and
|
||||
@ -2286,8 +2285,7 @@ def _cpp_pmap(
|
||||
# No tracers in the outputs. Checking for ShardedDeviceArray should be
|
||||
# sufficient, but we use the more general `DeviceArray`.
|
||||
all(
|
||||
isinstance(x, device_array.DeviceArray) or
|
||||
xc._version >= 96 and isinstance(x, xc.ArrayImpl)
|
||||
isinstance(x, device_array.DeviceArray) or isinstance(x, xc.ArrayImpl)
|
||||
for x in out_flat))
|
||||
|
||||
### If we can use the fastpath, we return required info to the caller.
|
||||
@ -3453,8 +3451,4 @@ def live_arrays(platform=None):
|
||||
|
||||
If platform is None, it is the default backend.
|
||||
"""
|
||||
if xc._version >= 102:
|
||||
return xb.get_backend(platform).live_arrays()
|
||||
|
||||
raise RuntimeError(
|
||||
"live_arrays() is not supported yet. Please update your jaxlib package.")
|
||||
return xb.get_backend(platform).live_arrays()
|
||||
|
@ -99,7 +99,7 @@ def _single_device_array_from_buf(buf, committed):
|
||||
committed=committed, _skip_checks=True)
|
||||
|
||||
|
||||
@use_cpp_class(xc.ArrayImpl if xc._version >= 99 else None)
|
||||
@use_cpp_class(xc.ArrayImpl)
|
||||
class ArrayImpl(basearray.Array):
|
||||
# TODO(yashkatariya): Add __slots__ here.
|
||||
|
||||
@ -612,9 +612,8 @@ core.pytype_aval_mappings[ArrayImpl] = abstract_arrays.canonical_concrete_aval
|
||||
xla.pytype_aval_mappings[ArrayImpl] = op.attrgetter('aval')
|
||||
xla.canonicalize_dtype_handlers[ArrayImpl] = pxla.identity
|
||||
api_util._shaped_abstractify_handlers[ArrayImpl] = op.attrgetter('aval')
|
||||
if xc._version >= 96:
|
||||
# TODO(jakevdp) replace this with true inheritance at the C++ level.
|
||||
basearray.Array.register(ArrayImpl)
|
||||
# TODO(jakevdp) replace this with true inheritance at the C++ level.
|
||||
basearray.Array.register(ArrayImpl)
|
||||
|
||||
|
||||
def _array_mlir_constant_handler(val, canonicalize_types=True):
|
||||
|
@ -52,6 +52,7 @@ from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import func as func_dialect
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
|
||||
treedef_is_leaf, tree_structure, treedef_tuple)
|
||||
from jax._src.tree_util import prefix_errors
|
||||
@ -161,12 +162,9 @@ def _cpp_pjit(fun: Callable, infer_params, static_argnums):
|
||||
|
||||
return outs, fastpath_data
|
||||
|
||||
if xc._version < 108:
|
||||
cpp_pjit_f = xc._xla.pjit(fun, cache_miss, static_argnums) # type: ignore
|
||||
else:
|
||||
cpp_pjit_f = xc._xla.pjit( # type: ignore
|
||||
getattr(fun, "__name__", "<unnamed function>"), # type:ignore
|
||||
cache_miss, static_argnums)
|
||||
cpp_pjit_f = xc._xla.pjit( # type: ignore
|
||||
getattr(fun, "__name__", "<unnamed function>"), # type:ignore
|
||||
cache_miss, static_argnums)
|
||||
|
||||
return wraps(fun)(cpp_pjit_f)
|
||||
|
||||
@ -479,7 +477,7 @@ def pjit(
|
||||
return (args_flat, local_in_avals, params, in_tree, out_tree(),
|
||||
donate_argnums)
|
||||
|
||||
if FLAGS.experimental_cpp_pjit and xc._version >= 111:
|
||||
if FLAGS.experimental_cpp_pjit and xla_extension_version >= 111:
|
||||
wrapped = _cpp_pjit(fun, infer_params, static_argnums)
|
||||
else:
|
||||
wrapped = _python_pjit(fun, infer_params)
|
||||
|
@ -33,7 +33,7 @@ Device = xc.Device
|
||||
Index = Tuple[slice, ...]
|
||||
XLADeviceAssignment = Sequence[Device]
|
||||
|
||||
@use_cpp_class(xc.Sharding if xc._version >= 94 else None)
|
||||
@use_cpp_class(xc.Sharding)
|
||||
class Sharding(metaclass=abc.ABCMeta):
|
||||
"""Abstract ``Sharding`` interface which describes how a ``jax.Array`` is laid out
|
||||
across devices.
|
||||
@ -99,7 +99,7 @@ class Sharding(metaclass=abc.ABCMeta):
|
||||
|
||||
# Shardings that inherit from XLACompatibleSharding should implement the
|
||||
# `_device_assignment` property and `_to_xla_op_sharding` method.
|
||||
@use_cpp_class(xc.XLACompatibleSharding if xc._version >= 94 else None)
|
||||
@use_cpp_class(xc.XLACompatibleSharding)
|
||||
class XLACompatibleSharding(Sharding, metaclass=abc.ABCMeta):
|
||||
"""A `Sharding` that describes shardings expressible to XLA.
|
||||
|
||||
@ -190,15 +190,6 @@ def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]
|
||||
return out
|
||||
|
||||
|
||||
def _enable_cpp_named_sharding():
|
||||
if xc._version >= 107:
|
||||
return xc.NamedSharding
|
||||
elif xc._version >= 95:
|
||||
return xc.MeshPspecSharding # type: ignore
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class _UnconstrainedPartitionSingleton:
|
||||
|
||||
def __str__(self):
|
||||
@ -235,7 +226,7 @@ class PartitionSpec(tuple):
|
||||
return (PartitionSpec, tuple(self))
|
||||
|
||||
|
||||
@use_cpp_class(_enable_cpp_named_sharding())
|
||||
@use_cpp_class(xc.NamedSharding)
|
||||
class NamedSharding(XLACompatibleSharding):
|
||||
r"""NamedSharding is a way to express ``Sharding``\s using named axes.
|
||||
|
||||
@ -367,7 +358,7 @@ def _get_replicated_op_sharding():
|
||||
return proto
|
||||
|
||||
|
||||
@use_cpp_class(xc.SingleDeviceSharding if xc._version >= 95 else None)
|
||||
@use_cpp_class(xc.SingleDeviceSharding)
|
||||
class SingleDeviceSharding(XLACompatibleSharding):
|
||||
"""A subclass of ``XLACompatibleSharding`` that places its data on a single device.
|
||||
|
||||
@ -415,7 +406,7 @@ class SingleDeviceSharding(XLACompatibleSharding):
|
||||
return _get_replicated_op_sharding()
|
||||
|
||||
|
||||
@use_cpp_class(xc.PmapSharding if xc._version >= 94 else None)
|
||||
@use_cpp_class(xc.PmapSharding)
|
||||
class PmapSharding(XLACompatibleSharding):
|
||||
|
||||
@use_cpp_method
|
||||
@ -624,7 +615,7 @@ class DeviceIdSet:
|
||||
self._ids == other._ids)
|
||||
|
||||
|
||||
@use_cpp_class(xc.OpShardingSharding if xc._version >= 95 else None)
|
||||
@use_cpp_class(xc.OpShardingSharding)
|
||||
class OpShardingSharding(XLACompatibleSharding):
|
||||
|
||||
@use_cpp_method
|
||||
|
@ -74,6 +74,7 @@ from jax._src.config import flags
|
||||
from jax._src.core import ConcreteArray, ShapedArray
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib import pmap_lib
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
@ -3563,7 +3564,7 @@ class MeshExecutable(stages.XlaExecutable):
|
||||
not self.unsafe_call.has_host_callbacks):
|
||||
return None
|
||||
|
||||
if not flags.FLAGS.experimental_cpp_pjit or xc._version < 111:
|
||||
if not flags.FLAGS.experimental_cpp_pjit or xla_extension_version < 111:
|
||||
return None
|
||||
|
||||
def aot_cache_miss(*args, **kwargs):
|
||||
@ -3583,13 +3584,7 @@ class MeshExecutable(stages.XlaExecutable):
|
||||
fastpath_data = None
|
||||
return outs, fastpath_data
|
||||
|
||||
if xc._version < 108:
|
||||
def dummy():
|
||||
pass
|
||||
dummy.__name__ = self.unsafe_call.name
|
||||
return xc._xla.pjit(dummy, aot_cache_miss, []) # type: ignore
|
||||
else:
|
||||
return xc._xla.pjit(self.unsafe_call.name, aot_cache_miss, []) # type: ignore
|
||||
return xc._xla.pjit(self.unsafe_call.name, aot_cache_miss, []) # type: ignore
|
||||
|
||||
|
||||
def _out_shardings_for_trivial(
|
||||
|
Loading…
x
Reference in New Issue
Block a user