mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Bump minimum jaxlib version to 0.4.11. xla_extension_version is 158 and mlir_api_version is 49. It will subsume https://github.com/google/jax/pull/16161#issuecomment-1564977332
PiperOrigin-RevId: 537047525
This commit is contained in:
parent
adca0fa9b8
commit
ae9d1498e5
@ -62,7 +62,6 @@ from jax._src.api_util import (
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lib import jax_jit
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib import pmap_lib
|
||||
from jax._src.sharding import Sharding
|
||||
from jax._src.sharding_impls import PmapSharding
|
||||
@ -2894,7 +2893,6 @@ def clear_caches():
|
||||
xc._xla.PjitFunctionCache.clear_all()
|
||||
|
||||
# Clear all C++ compiled executable caches for pmap
|
||||
if xla_extension_version >= 146: # TODO(frostig): remove when ready
|
||||
for fun in _pmap_cache_clears:
|
||||
fun._cache_clear()
|
||||
|
||||
|
@ -32,7 +32,6 @@ from jax._src import profiler
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.config import config
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src.interpreters import xla
|
||||
@ -425,7 +424,7 @@ class ArrayImpl(basearray.Array):
|
||||
|
||||
def addressable_data(self, index: int) -> ArrayImpl:
|
||||
self._check_if_deleted()
|
||||
if self.is_fully_replicated and xla_extension_version >= 148:
|
||||
if self.is_fully_replicated:
|
||||
return self._fully_replicated_shard()
|
||||
return self._arrays[index]
|
||||
|
||||
|
@ -35,7 +35,6 @@ from jax._src.compilation_cache_interface import CacheInterface
|
||||
from jax._src.gfile_cache import GFileCache
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import version_str as jaxlib_version_str
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir import passmanager as pm
|
||||
|
||||
@ -200,10 +199,7 @@ def _hash_devices(hash_obj, devices: np.ndarray) -> None:
|
||||
_hash_string(hash_obj, device.device_kind)
|
||||
|
||||
def _hash_compile_options(hash_obj, compile_options_obj):
|
||||
if xla_extension_version >= 145:
|
||||
expected_num_compile_options = 12
|
||||
else:
|
||||
expected_num_compile_options = 11
|
||||
# Ignore private and built-in methods. These can unexpectedly change and lead
|
||||
# to false positives, e.g. when different Python versions include different
|
||||
# built-ins.
|
||||
@ -232,7 +228,6 @@ def _hash_compile_options(hash_obj, compile_options_obj):
|
||||
if compile_options_obj.device_assignment is not None:
|
||||
hash_obj.update(compile_options_obj.device_assignment.serialize())
|
||||
_hash_bool(hash_obj, compile_options_obj.compile_portable_executable)
|
||||
if xla_extension_version >= 145:
|
||||
_hash_int(hash_obj, len(compile_options_obj.env_option_overrides))
|
||||
for kv in compile_options_obj.env_option_overrides:
|
||||
_hash_string(hash_obj, kv[0])
|
||||
|
@ -36,7 +36,6 @@ from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.sharding import Sharding
|
||||
@ -355,17 +354,6 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *,
|
||||
op_sharding = hlo_sharding.to_proto()
|
||||
return _op_sharding_callback(op_sharding)
|
||||
|
||||
# Here we store information in a container that we store globally so the
|
||||
# custom partitioning code can access it.
|
||||
sharding_callback_info = ShardingCallbackInfo(_hlo_sharding_callback,
|
||||
ctx.module_context)
|
||||
if xla_extension_version < 150:
|
||||
key = str(id(sharding_callback_info))
|
||||
sharding_callbacks[key] = sharding_callback_info
|
||||
# We need to make sure `sharding_callback_info` is still alive when the SPMD
|
||||
# partitioner runs so we keep it alive by attaching it to the executable.
|
||||
ctx.module_context.add_keepalive(sharding_callback_info)
|
||||
else:
|
||||
key = xc.encode_inspect_sharding_callback(_hlo_sharding_callback)
|
||||
# We need to make sure `_hlo_sharding_callback` is still alive when the SPMD
|
||||
# partitioner runs so we keep it alive by attaching it to the executable. #
|
||||
@ -420,15 +408,6 @@ def inspect_sharding_infer_sharding_from_operands(arg_shapes, arg_shardings,
|
||||
del arg_shapes, shape, backend_string
|
||||
return arg_shardings[0]
|
||||
|
||||
if xla_extension_version < 150:
|
||||
xc.register_custom_call_partitioner( # pytype: disable=module-attr
|
||||
_INSPECT_SHARDING_CALL_NAME,
|
||||
inspect_sharding_prop_user_sharding,
|
||||
inspect_sharding_partition,
|
||||
inspect_sharding_infer_sharding_from_operands,
|
||||
True,
|
||||
)
|
||||
|
||||
def _slice_to_chunk_idx(size: int, slc: slice) -> int:
|
||||
if slc.stop == slc.start == None:
|
||||
return 0
|
||||
|
@ -441,18 +441,9 @@ approx_top_k_p = core.Primitive('approx_top_k')
|
||||
approx_top_k_p.multiple_results = True
|
||||
approx_top_k_p.def_impl(partial(dispatch.apply_primitive, approx_top_k_p))
|
||||
approx_top_k_p.def_abstract_eval(_approx_top_k_abstract_eval)
|
||||
if xc.mlir_api_version > 48:
|
||||
mlir.register_lowering(approx_top_k_p,
|
||||
mlir.register_lowering(approx_top_k_p,
|
||||
partial(_approx_top_k_lowering, fallback=True))
|
||||
mlir.register_lowering(approx_top_k_p, _approx_top_k_lowering,
|
||||
platform='tpu')
|
||||
elif xc.mlir_api_version == 48:
|
||||
xla.register_translation(approx_top_k_p, _approx_top_k_fallback_translation)
|
||||
mlir.register_lowering(approx_top_k_p, _approx_top_k_lowering,
|
||||
platform='tpu')
|
||||
else:
|
||||
xla.register_translation(approx_top_k_p, _approx_top_k_fallback_translation)
|
||||
xla.register_translation(approx_top_k_p, _approx_top_k_tpu_translation,
|
||||
mlir.register_lowering(approx_top_k_p, _approx_top_k_lowering,
|
||||
platform='tpu')
|
||||
batching.primitive_batchers[approx_top_k_p] = _approx_top_k_batch_rule
|
||||
ad.primitive_jvps[approx_top_k_p] = _approx_top_k_jvp
|
||||
|
@ -46,7 +46,6 @@ from jax._src.lib import gpu_sparse
|
||||
from jax._src.lib import lapack
|
||||
from jax._src.lib import version as jaxlib_version
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import chlo
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
@ -631,20 +630,13 @@ def _eigh_jacobi_lowering_rule(ctx, operand, lower, sort_eigenvalues):
|
||||
result_shapes = None
|
||||
op = mlir.custom_call(
|
||||
"Eigh",
|
||||
(result_types if xla_extension_version >= 150 else
|
||||
[ir.TupleType.get_tuple(result_types)]),
|
||||
result_types,
|
||||
[operand],
|
||||
backend_config=backend_config,
|
||||
api_version=1,
|
||||
result_shapes=result_shapes,
|
||||
)
|
||||
if xla_extension_version >= 150:
|
||||
return op.results[1], op.results[0]
|
||||
else:
|
||||
return (
|
||||
hlo.GetTupleElementOp(op, 1).result,
|
||||
hlo.GetTupleElementOp(op, 0).result,
|
||||
)
|
||||
|
||||
eigh_jacobi_p = Primitive('eigh_jacobi')
|
||||
eigh_jacobi_p.multiple_results = True
|
||||
@ -1269,20 +1261,12 @@ def _lu_tpu_lowering_rule(ctx, operand):
|
||||
mlir.aval_to_ir_type(ctx.avals_out[2])
|
||||
]
|
||||
op = hlo.CustomCallOp(
|
||||
(result_types if xla_extension_version >= 150 else
|
||||
[ir.TupleType.get(result_types)]),
|
||||
result_types,
|
||||
[operand],
|
||||
call_target_name=ir.StringAttr.get("LuDecomposition"),
|
||||
has_side_effect=ir.BoolAttr.get(False),
|
||||
)
|
||||
if xla_extension_version >= 150:
|
||||
return op.results
|
||||
else:
|
||||
return (
|
||||
hlo.GetTupleElementOp(op, 0).result,
|
||||
hlo.GetTupleElementOp(op, 1).result,
|
||||
hlo.GetTupleElementOp(op, 2).result,
|
||||
)
|
||||
|
||||
|
||||
lu_p = Primitive('lu')
|
||||
@ -1411,19 +1395,12 @@ def _geqrf_lowering_rule(ctx, operand):
|
||||
result_shapes = None
|
||||
op = mlir.custom_call(
|
||||
"Qr",
|
||||
(result_types if xla_extension_version >= 150
|
||||
else [ir.TupleType.get(result_types)]),
|
||||
result_types,
|
||||
[operand],
|
||||
api_version=1,
|
||||
result_shapes=result_shapes
|
||||
)
|
||||
if xla_extension_version >= 150:
|
||||
return op.results
|
||||
else:
|
||||
return (
|
||||
hlo.GetTupleElementOp(op, 0).result,
|
||||
hlo.GetTupleElementOp(op, 1).result,
|
||||
)
|
||||
|
||||
def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a):
|
||||
if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
|
||||
|
@ -19,7 +19,6 @@ from typing import List, Sequence, Tuple, Union
|
||||
import numpy as np
|
||||
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
|
||||
|
||||
def get_num_ways_dim_sharded(
|
||||
@ -49,11 +48,7 @@ def get_num_ways_dim_sharded(
|
||||
|
||||
def is_op_sharding_replicated(op: Union[xc.OpSharding, xc.HloSharding]) -> bool:
|
||||
if isinstance(op, xc.OpSharding):
|
||||
if xla_extension_version >= 147:
|
||||
return xc._xla.is_op_sharding_fully_replicated(op)
|
||||
if len(op.tile_assignment_devices) == 1:
|
||||
return True
|
||||
return xc.HloSharding.from_proto(op).is_replicated() # type: ignore
|
||||
else:
|
||||
assert isinstance(op, xc.HloSharding)
|
||||
if op.num_devices() == 1:
|
||||
|
@ -39,7 +39,6 @@ from jax._src import op_shardings
|
||||
from jax._src import util
|
||||
from jax._src.lib import pmap_lib
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
|
||||
unsafe_map, map = map, util.safe_map
|
||||
|
||||
@ -106,7 +105,7 @@ def sharding_spec_sharding_proto(
|
||||
else:
|
||||
util.assert_unreachable(assignment)
|
||||
|
||||
if xla_extension_version >= 157 and hlo_sharding:
|
||||
if hlo_sharding:
|
||||
if len(replicated_maxes) == len(self.mesh_mapping) and not special_axes:
|
||||
return xc.HloSharding.replicate()
|
||||
else:
|
||||
@ -153,7 +152,7 @@ def sharding_spec_sharding_proto(
|
||||
|
||||
proto_mesh = mesh.transpose(mesh_permutation).reshape(new_mesh_shape)
|
||||
|
||||
if xla_extension_version >= 157 and hlo_sharding:
|
||||
if hlo_sharding:
|
||||
return xc.HloSharding.iota_tile(
|
||||
dims=new_mesh_shape, reshape_dims=list(proto_mesh.shape),
|
||||
transpose_perm=mesh_permutation, subgroup_types=last_tile_dims)
|
||||
|
@ -31,7 +31,6 @@ import sys
|
||||
import threading
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
|
||||
import warnings
|
||||
from jax._src.lib import xla_extension_version
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lib
|
||||
@ -164,11 +163,7 @@ def get_compile_options(
|
||||
compile_options.device_assignment = device_assignment
|
||||
|
||||
if env_options_overrides is not None:
|
||||
if xla_extension_version >= 145:
|
||||
compile_options.env_option_overrides = list(env_options_overrides.items())
|
||||
else:
|
||||
raise TypeError(
|
||||
"`env_options_overrides` is only supported in later versions of jaxlib")
|
||||
|
||||
debug_options = compile_options.executable_build_options.debug_options
|
||||
if lib.cuda_path is not None:
|
||||
@ -410,7 +405,6 @@ def register_plugin(
|
||||
options: Optional. It is used when creating a PJRT plugin client.
|
||||
"""
|
||||
def factory():
|
||||
if xla_extension_version >= 152:
|
||||
# Plugin may already be statically linked in some configurations.
|
||||
if not xla_client.pjrt_plugin_loaded(plugin_name):
|
||||
if library_path is None:
|
||||
@ -419,13 +413,6 @@ def register_plugin(
|
||||
' plugin.'
|
||||
)
|
||||
xla_client.load_pjrt_plugin_dynamically(plugin_name, library_path)
|
||||
else:
|
||||
if library_path is None:
|
||||
raise ValueError(
|
||||
'The library path is None when trying to dynamically load the'
|
||||
' plugin.'
|
||||
)
|
||||
xla_client.load_pjrt_plugin_dynamically(plugin_name, library_path)
|
||||
return xla_client.make_c_api_client(plugin_name, options)
|
||||
|
||||
logger.debug(
|
||||
@ -810,11 +797,9 @@ def using_pjrt_c_api(backend=None):
|
||||
|
||||
# TODO(parkers): Get rid of this in favor of a generic way to get topologies.
|
||||
def make_pjrt_tpu_topology(topology_name=None, **kwargs):
|
||||
if xla_extension_version >= 153:
|
||||
# TODO(b/261484192): Make a system for lazily loading libtpu.so and call
|
||||
# that inside make_tfrt_tpu_c_api_device_topology.
|
||||
get_backend() # Properly initialize libtpu.so.
|
||||
return xla_client.make_tfrt_tpu_c_api_device_topology(
|
||||
topology_name, **kwargs
|
||||
)
|
||||
raise NotImplementedError('make_pjrt_tpu_topology is not implemented')
|
||||
|
@ -27,7 +27,6 @@ from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src import custom_api_util
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.api_util import flatten_fun_nokwargs
|
||||
from jax._src.api_util import argnums_partial
|
||||
|
||||
@ -496,8 +495,7 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
|
||||
|
||||
result_types = [mlir.aval_to_ir_type(s) for s in call.out_avals]
|
||||
out = hlo.CustomCallOp(
|
||||
(result_types if xla_extension_version >= 150 or len(result_types) == 1
|
||||
else [ir.TupleType.get(result_types)]),
|
||||
result_types,
|
||||
list(values),
|
||||
call_target_name=ir.StringAttr.get(_CUSTOM_PARTITIONING_CALL_NAME),
|
||||
has_side_effect=ir.BoolAttr.get(False),
|
||||
@ -506,13 +504,7 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
|
||||
backend_config=ir.StringAttr.get(key),
|
||||
operand_layouts=None,
|
||||
result_layouts=None)
|
||||
if xla_extension_version >= 150 or len(result_types) == 1:
|
||||
return out.results
|
||||
else:
|
||||
return [
|
||||
hlo.GetTupleElementOp(out, mlir.i32_attr(i)).result
|
||||
for i in range(len(result_types))
|
||||
]
|
||||
|
||||
mlir.register_lowering(custom_partitioning_p,
|
||||
_custom_partitioning_lowering_rule)
|
||||
|
@ -16,7 +16,7 @@
|
||||
# eval()-ed by setup.py, so it should not have any dependencies.
|
||||
|
||||
__version__ = "0.4.12"
|
||||
_minimum_jaxlib_version = "0.4.7"
|
||||
_minimum_jaxlib_version = "0.4.11"
|
||||
|
||||
def _version_as_tuple(version_str):
|
||||
return tuple(int(i) for i in version_str.split(".") if i.isdigit())
|
||||
|
@ -26,7 +26,6 @@ from jax._src.config import flags
|
||||
from jax.experimental.pjit import pjit
|
||||
from jax.experimental.serialize_executable import (
|
||||
serialize, deserialize_and_load)
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax.experimental import topologies
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
@ -40,9 +39,6 @@ with contextlib.suppress(ImportError):
|
||||
pytestmark = pytest.mark.multiaccelerator
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
xla_extension_version < 151, 'Test requires xla_extension_version >= 151'
|
||||
)
|
||||
class JaxAotTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.skip_on_devices('cpu', 'gpu')
|
||||
|
@ -68,7 +68,6 @@ from jax._src import prng
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src import test_util as jtu
|
||||
from jax import tree_util
|
||||
from jax._src import linear_util as lu
|
||||
@ -1177,8 +1176,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertIn("jax.result_info = \"['a']\"", mhlo_str)
|
||||
self.assertIn("jax.result_info = \"['b'][0][0]\"", mhlo_str)
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 145,
|
||||
'Test requires xla_extension_version >= 145')
|
||||
def test_jit_lower_compile_with_compiler_options(self):
|
||||
def f(x):
|
||||
return jnp.sqrt(x ** 2) + 1.
|
||||
@ -1188,8 +1185,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
lowered.compile( # doesn't crash
|
||||
compiler_options={"xla_embed_ir_in_executable": True})
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 145,
|
||||
'Test requires xla_extension_version >= 145')
|
||||
def test_jit_lower_compile_with_compiler_options_invalid(self):
|
||||
def f(x):
|
||||
return jnp.sqrt(x ** 2) + 1.
|
||||
@ -1207,8 +1202,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
lambda: lowered.compile(
|
||||
compiler_options={"xla_embed_ir_in_executable": "invalid_value"}))
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 145,
|
||||
'Test requires xla_extension_version >= 145')
|
||||
def test_jit_lower_compile_with_compiler_options_multiple(self):
|
||||
def f(x):
|
||||
return jnp.sqrt(x ** 2) + 1.
|
||||
@ -4231,8 +4224,6 @@ class APITest(jtu.JaxTestCase):
|
||||
out = jax.grad(f)(3.0) # doesn't crash
|
||||
self.assertAllClose(out, 1., check_dtypes=False)
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 146,
|
||||
'Test requires xla_extension_version >= 146')
|
||||
def test_cache_clear_pmap(self):
|
||||
@jax.pmap
|
||||
def f(i):
|
||||
|
@ -30,7 +30,6 @@ from jax._src import op_shardings
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.util import safe_zip
|
||||
from jax._src.sharding_impls import _from_op_sharding_to_pos_sharding
|
||||
from jax.experimental.pjit import pjit
|
||||
@ -694,7 +693,6 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
dtype=jtu.dtypes.all,
|
||||
shape=[(), (10), (2, 3)],
|
||||
)
|
||||
@unittest.skipIf(xla_extension_version < 158, "Test requires jaxlib >= 0.4.11")
|
||||
def test_buffer_protocol(self, dtype, shape):
|
||||
if jtu.device_under_test() != "cpu":
|
||||
raise unittest.SkipTest("Buffer protocol only works on CPU")
|
||||
@ -716,7 +714,6 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
y_bytes = memoryview(y).tobytes()
|
||||
self.assertEqual(x_bytes, y_bytes)
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 157, "Test requires jaxlib >= 0.4.11")
|
||||
def test_buffer_protocol_deletion(self):
|
||||
if jtu.device_under_test() != "cpu":
|
||||
raise unittest.SkipTest("Buffer protocol only works on CPU")
|
||||
@ -742,8 +739,6 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
self.assertArraysEqual(np.arange(8.), x)
|
||||
|
||||
def test_array_fully_replicated_shard(self):
|
||||
if xla_extension_version < 148:
|
||||
self.skipTest('Requires xla_extension_version >= 148')
|
||||
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
inp_shape = (8, 2)
|
||||
|
@ -28,7 +28,6 @@ from jax._src import debugging
|
||||
from jax._src import dispatch
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.lib import xla_extension_version
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
@ -1150,8 +1149,6 @@ class InspectShardingTest(jtu.JaxTestCase):
|
||||
|
||||
if jtu.is_cloud_tpu():
|
||||
raise unittest.SkipTest("Inspect sharding is not supported on libtpu.")
|
||||
if xla_bridge.using_pjrt_c_api() and xla_extension_version < 150:
|
||||
raise unittest.SkipTest("Inspect sharding is not supported on Cloud TPU")
|
||||
|
||||
is_called = False
|
||||
def _cb(sd):
|
||||
|
@ -56,7 +56,6 @@ from jax.interpreters import mlir
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.util import curry, unzip2, safe_zip
|
||||
|
||||
from jax import config
|
||||
@ -1164,9 +1163,6 @@ class CustomPartitionerTest(jtu.JaxTestCase):
|
||||
def test_custom_partitioner(self):
|
||||
self.skip_if_custom_partitioning_not_supported()
|
||||
|
||||
if xla_extension_version < 154:
|
||||
self.skipTest('Requires xla_extension_version >= 154')
|
||||
|
||||
def partition(precision, arg_shapes, result_shape):
|
||||
arg_shardings = jax.tree_map(lambda s: s.sharding, arg_shapes)
|
||||
result_sharding = result_shape[0].sharding
|
||||
@ -1285,9 +1281,6 @@ class CustomPartitionerTest(jtu.JaxTestCase):
|
||||
@jtu.with_mesh([('x', 4), ('y', 2)])
|
||||
def test_custom_partitioner_invalid_sharding(self):
|
||||
self.skip_if_custom_partitioning_not_supported()
|
||||
if xla_extension_version < 149:
|
||||
self.skipTest('Requires xla_extension_version >= 149')
|
||||
|
||||
def partition(arg_shapes, result_shape):
|
||||
def lower_fn(x):
|
||||
return x
|
||||
@ -3888,7 +3881,6 @@ class UtilTest(jtu.JaxTestCase):
|
||||
hs2 = xc.HloSharding.from_proto(op2)
|
||||
hs3 = xc.HloSharding.from_proto(op3)
|
||||
|
||||
if xla_extension_version >= 156:
|
||||
self.assertEqual(hs1, xc.HloSharding.iota_tile((2, 2)))
|
||||
self.assertEqual(hs2, xc.HloSharding.iota_tile((2, 2)))
|
||||
self.assertEqual(hs3, xc.HloSharding.iota_tile((4, 2)))
|
||||
@ -3919,7 +3911,6 @@ class UtilTest(jtu.JaxTestCase):
|
||||
|
||||
hs1 = xc.HloSharding.from_proto(op1)
|
||||
hs2 = xc.HloSharding.from_proto(op2)
|
||||
if xla_extension_version >= 156:
|
||||
self.assertEqual(
|
||||
hs1,
|
||||
xc.HloSharding.iota_tile(
|
||||
@ -3972,8 +3963,6 @@ class UtilTest(jtu.JaxTestCase):
|
||||
self.assertNotEqual(hash(hs1), hash(hs2))
|
||||
|
||||
def test_hlo_sharding_iota_tile_error(self):
|
||||
if xla_extension_version < 156:
|
||||
self.skipTest('Requires xla_extension_version >= 156')
|
||||
self.assertRaisesRegex(
|
||||
xla_extension.XlaRuntimeError,
|
||||
'INVALID_ARGUMENT: `dims` should not be empty.',
|
||||
@ -4091,7 +4080,6 @@ class UtilTest(jtu.JaxTestCase):
|
||||
self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2))
|
||||
self.assertTrue(op_shardings.are_op_shardings_equal(op1, op3))
|
||||
|
||||
if xla_extension_version >= 156:
|
||||
hs1 = xc.HloSharding.from_proto(op1)
|
||||
self.assertEqual(
|
||||
hs1,
|
||||
@ -4124,9 +4112,6 @@ class UtilTest(jtu.JaxTestCase):
|
||||
)
|
||||
|
||||
def test_hlo_sharding_manual_replicated(self):
|
||||
if xla_extension_version < 156:
|
||||
self.skipTest('Requires xla_extension_version >= 156')
|
||||
|
||||
hs1 = xc.HloSharding.manual()
|
||||
self.assertTrue(hs1.is_manual())
|
||||
self.assertFalse(hs1.tile_assignment_devices())
|
||||
|
@ -47,7 +47,6 @@ from jax._src import sharding_impls
|
||||
from jax._src import sharding_specs
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.util import safe_map, safe_zip
|
||||
from jax._src.interpreters import pxla
|
||||
from jax.interpreters import xla
|
||||
@ -332,8 +331,6 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
f = f.lower(x).compile()
|
||||
self.assertIsNotNone(f.runtime_executable())
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 145,
|
||||
'Test requires xla_extension_version >= 145')
|
||||
def test_jit_lower_compile_with_compiler_options(self):
|
||||
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
|
||||
shape = (jax.device_count(), 4)
|
||||
@ -343,8 +340,6 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
lowered.compile( # doesn't crash
|
||||
compiler_options={"xla_embed_ir_in_executable": True})
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 145,
|
||||
'Test requires xla_extension_version >= 145')
|
||||
def test_jit_lower_compile_with_compiler_options_invalid(self):
|
||||
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
|
||||
shape = (jax.device_count(), 4)
|
||||
@ -361,8 +356,6 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
lambda: lowered.compile(
|
||||
compiler_options={"xla_embed_ir_in_executable": "invalid_value"}))
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 145,
|
||||
'Test requires xla_extension_version >= 145')
|
||||
def test_jit_lower_compile_with_compiler_options_multiple(self):
|
||||
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
|
||||
shape = (jax.device_count(), 4)
|
||||
|
@ -21,7 +21,6 @@ from absl.testing import absltest
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax.interpreters import xla
|
||||
|
||||
from jax._src.config import config
|
||||
@ -97,15 +96,9 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
xb.register_pjrt_plugin_factories_from_env()
|
||||
client_factory, priotiy = xb._backend_factories["name1"]
|
||||
with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make:
|
||||
with mock.patch.object(xc, "load_pjrt_plugin_dynamically", autospec=True):
|
||||
with mock.patch.object(
|
||||
xc, "load_pjrt_plugin_dynamically", autospec=True
|
||||
) as mock_load_plugin:
|
||||
if xla_extension_version >= 152:
|
||||
with mock.patch.object(
|
||||
xc, "pjrt_plugin_loaded", autospec=True
|
||||
) as mock_plugin_loaded:
|
||||
client_factory()
|
||||
else:
|
||||
xc, "pjrt_plugin_loaded", autospec=True) as mock_plugin_loaded:
|
||||
client_factory()
|
||||
|
||||
self.assertRegex(
|
||||
@ -116,10 +109,7 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
self.assertIn("name1", xb._backend_factories)
|
||||
self.assertIn("name2", xb._backend_factories)
|
||||
self.assertEqual(priotiy, 400)
|
||||
if xla_extension_version >= 152:
|
||||
mock_plugin_loaded.assert_called_once_with("name1")
|
||||
else:
|
||||
mock_load_plugin.assert_called_once_with("name1", "path1")
|
||||
mock_make.assert_called_once_with("name1", None)
|
||||
|
||||
def test_register_plugin_with_config(self):
|
||||
@ -130,25 +120,14 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
xb.register_pjrt_plugin_factories_from_env()
|
||||
client_factory, priority = xb._backend_factories["name1"]
|
||||
with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make:
|
||||
with mock.patch.object(xc, "load_pjrt_plugin_dynamically", autospec=True):
|
||||
with mock.patch.object(
|
||||
xc, "load_pjrt_plugin_dynamically", autospec=True
|
||||
) as mock_load_plugin:
|
||||
if xla_extension_version >= 152:
|
||||
with mock.patch.object(
|
||||
xc, "pjrt_plugin_loaded", autospec=True
|
||||
) as mock_plugin_loaded:
|
||||
client_factory()
|
||||
else:
|
||||
xc, "pjrt_plugin_loaded", autospec=True) as mock_plugin_loaded:
|
||||
client_factory()
|
||||
|
||||
self.assertIn("name1", xb._backend_factories)
|
||||
self.assertEqual(priority, 400)
|
||||
if xla_extension_version >= 152:
|
||||
mock_plugin_loaded.assert_called_once_with("name1")
|
||||
else:
|
||||
mock_load_plugin.assert_called_once_with(
|
||||
"name1", "/path/pjrt_plugin_name1.so"
|
||||
)
|
||||
mock_make.assert_called_once_with(
|
||||
"name1",
|
||||
{
|
||||
|
Loading…
x
Reference in New Issue
Block a user