Bump minimum jaxlib version to 0.4.11. xla_extension_version is 158 and mlir_api_version is 49. It will subsume https://github.com/google/jax/pull/16161#issuecomment-1564977332

PiperOrigin-RevId: 537047525
This commit is contained in:
Yash Katariya 2023-06-01 09:36:32 -07:00 committed by jax authors
parent adca0fa9b8
commit ae9d1498e5
18 changed files with 110 additions and 264 deletions

View File

@ -62,7 +62,6 @@ from jax._src.api_util import (
from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib import pmap_lib
from jax._src.sharding import Sharding
from jax._src.sharding_impls import PmapSharding
@ -2894,7 +2893,6 @@ def clear_caches():
xc._xla.PjitFunctionCache.clear_all()
# Clear all C++ compiled executable caches for pmap
if xla_extension_version >= 146: # TODO(frostig): remove when ready
for fun in _pmap_cache_clears:
fun._cache_clear()

View File

@ -32,7 +32,6 @@ from jax._src import profiler
from jax._src import xla_bridge
from jax._src.config import config
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
@ -425,7 +424,7 @@ class ArrayImpl(basearray.Array):
def addressable_data(self, index: int) -> ArrayImpl:
self._check_if_deleted()
if self.is_fully_replicated and xla_extension_version >= 148:
if self.is_fully_replicated:
return self._fully_replicated_shard()
return self._arrays[index]

View File

@ -35,7 +35,6 @@ from jax._src.compilation_cache_interface import CacheInterface
from jax._src.gfile_cache import GFileCache
from jax._src.lib import xla_client
from jax._src.lib import version_str as jaxlib_version_str
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir import passmanager as pm
@ -200,10 +199,7 @@ def _hash_devices(hash_obj, devices: np.ndarray) -> None:
_hash_string(hash_obj, device.device_kind)
def _hash_compile_options(hash_obj, compile_options_obj):
if xla_extension_version >= 145:
expected_num_compile_options = 12
else:
expected_num_compile_options = 11
# Ignore private and built-in methods. These can unexpectedly change and lead
# to false positives, e.g. when different Python versions include different
# built-ins.
@ -232,7 +228,6 @@ def _hash_compile_options(hash_obj, compile_options_obj):
if compile_options_obj.device_assignment is not None:
hash_obj.update(compile_options_obj.device_assignment.serialize())
_hash_bool(hash_obj, compile_options_obj.compile_portable_executable)
if xla_extension_version >= 145:
_hash_int(hash_obj, len(compile_options_obj.env_option_overrides))
for kv in compile_options_obj.env_option_overrides:
_hash_string(hash_obj, kv[0])

View File

@ -36,7 +36,6 @@ from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.sharding import Sharding
@ -355,17 +354,6 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *,
op_sharding = hlo_sharding.to_proto()
return _op_sharding_callback(op_sharding)
# Here we store information in a container that we store globally so the
# custom partitioning code can access it.
sharding_callback_info = ShardingCallbackInfo(_hlo_sharding_callback,
ctx.module_context)
if xla_extension_version < 150:
key = str(id(sharding_callback_info))
sharding_callbacks[key] = sharding_callback_info
# We need to make sure `sharding_callback_info` is still alive when the SPMD
# partitioner runs so we keep it alive by attaching it to the executable.
ctx.module_context.add_keepalive(sharding_callback_info)
else:
key = xc.encode_inspect_sharding_callback(_hlo_sharding_callback)
# We need to make sure `_hlo_sharding_callback` is still alive when the SPMD
# partitioner runs so we keep it alive by attaching it to the executable. #
@ -420,15 +408,6 @@ def inspect_sharding_infer_sharding_from_operands(arg_shapes, arg_shardings,
del arg_shapes, shape, backend_string
return arg_shardings[0]
if xla_extension_version < 150:
xc.register_custom_call_partitioner( # pytype: disable=module-attr
_INSPECT_SHARDING_CALL_NAME,
inspect_sharding_prop_user_sharding,
inspect_sharding_partition,
inspect_sharding_infer_sharding_from_operands,
True,
)
def _slice_to_chunk_idx(size: int, slc: slice) -> int:
if slc.stop == slc.start == None:
return 0

View File

@ -441,18 +441,9 @@ approx_top_k_p = core.Primitive('approx_top_k')
approx_top_k_p.multiple_results = True
approx_top_k_p.def_impl(partial(dispatch.apply_primitive, approx_top_k_p))
approx_top_k_p.def_abstract_eval(_approx_top_k_abstract_eval)
if xc.mlir_api_version > 48:
mlir.register_lowering(approx_top_k_p,
mlir.register_lowering(approx_top_k_p,
partial(_approx_top_k_lowering, fallback=True))
mlir.register_lowering(approx_top_k_p, _approx_top_k_lowering,
platform='tpu')
elif xc.mlir_api_version == 48:
xla.register_translation(approx_top_k_p, _approx_top_k_fallback_translation)
mlir.register_lowering(approx_top_k_p, _approx_top_k_lowering,
platform='tpu')
else:
xla.register_translation(approx_top_k_p, _approx_top_k_fallback_translation)
xla.register_translation(approx_top_k_p, _approx_top_k_tpu_translation,
mlir.register_lowering(approx_top_k_p, _approx_top_k_lowering,
platform='tpu')
batching.primitive_batchers[approx_top_k_p] = _approx_top_k_batch_rule
ad.primitive_jvps[approx_top_k_p] = _approx_top_k_jvp

View File

@ -46,7 +46,6 @@ from jax._src.lib import gpu_sparse
from jax._src.lib import lapack
from jax._src.lib import version as jaxlib_version
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import hlo
@ -631,20 +630,13 @@ def _eigh_jacobi_lowering_rule(ctx, operand, lower, sort_eigenvalues):
result_shapes = None
op = mlir.custom_call(
"Eigh",
(result_types if xla_extension_version >= 150 else
[ir.TupleType.get_tuple(result_types)]),
result_types,
[operand],
backend_config=backend_config,
api_version=1,
result_shapes=result_shapes,
)
if xla_extension_version >= 150:
return op.results[1], op.results[0]
else:
return (
hlo.GetTupleElementOp(op, 1).result,
hlo.GetTupleElementOp(op, 0).result,
)
eigh_jacobi_p = Primitive('eigh_jacobi')
eigh_jacobi_p.multiple_results = True
@ -1269,20 +1261,12 @@ def _lu_tpu_lowering_rule(ctx, operand):
mlir.aval_to_ir_type(ctx.avals_out[2])
]
op = hlo.CustomCallOp(
(result_types if xla_extension_version >= 150 else
[ir.TupleType.get(result_types)]),
result_types,
[operand],
call_target_name=ir.StringAttr.get("LuDecomposition"),
has_side_effect=ir.BoolAttr.get(False),
)
if xla_extension_version >= 150:
return op.results
else:
return (
hlo.GetTupleElementOp(op, 0).result,
hlo.GetTupleElementOp(op, 1).result,
hlo.GetTupleElementOp(op, 2).result,
)
lu_p = Primitive('lu')
@ -1411,19 +1395,12 @@ def _geqrf_lowering_rule(ctx, operand):
result_shapes = None
op = mlir.custom_call(
"Qr",
(result_types if xla_extension_version >= 150
else [ir.TupleType.get(result_types)]),
result_types,
[operand],
api_version=1,
result_shapes=result_shapes
)
if xla_extension_version >= 150:
return op.results
else:
return (
hlo.GetTupleElementOp(op, 0).result,
hlo.GetTupleElementOp(op, 1).result,
)
def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a):
if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):

View File

@ -19,7 +19,6 @@ from typing import List, Sequence, Tuple, Union
import numpy as np
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
def get_num_ways_dim_sharded(
@ -49,11 +48,7 @@ def get_num_ways_dim_sharded(
def is_op_sharding_replicated(op: Union[xc.OpSharding, xc.HloSharding]) -> bool:
if isinstance(op, xc.OpSharding):
if xla_extension_version >= 147:
return xc._xla.is_op_sharding_fully_replicated(op)
if len(op.tile_assignment_devices) == 1:
return True
return xc.HloSharding.from_proto(op).is_replicated() # type: ignore
else:
assert isinstance(op, xc.HloSharding)
if op.num_devices() == 1:

View File

@ -39,7 +39,6 @@ from jax._src import op_shardings
from jax._src import util
from jax._src.lib import pmap_lib
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
unsafe_map, map = map, util.safe_map
@ -106,7 +105,7 @@ def sharding_spec_sharding_proto(
else:
util.assert_unreachable(assignment)
if xla_extension_version >= 157 and hlo_sharding:
if hlo_sharding:
if len(replicated_maxes) == len(self.mesh_mapping) and not special_axes:
return xc.HloSharding.replicate()
else:
@ -153,7 +152,7 @@ def sharding_spec_sharding_proto(
proto_mesh = mesh.transpose(mesh_permutation).reshape(new_mesh_shape)
if xla_extension_version >= 157 and hlo_sharding:
if hlo_sharding:
return xc.HloSharding.iota_tile(
dims=new_mesh_shape, reshape_dims=list(proto_mesh.shape),
transpose_perm=mesh_permutation, subgroup_types=last_tile_dims)

View File

@ -31,7 +31,6 @@ import sys
import threading
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
import warnings
from jax._src.lib import xla_extension_version
import numpy as np
from jax._src import lib
@ -164,11 +163,7 @@ def get_compile_options(
compile_options.device_assignment = device_assignment
if env_options_overrides is not None:
if xla_extension_version >= 145:
compile_options.env_option_overrides = list(env_options_overrides.items())
else:
raise TypeError(
"`env_options_overrides` is only supported in later versions of jaxlib")
debug_options = compile_options.executable_build_options.debug_options
if lib.cuda_path is not None:
@ -410,7 +405,6 @@ def register_plugin(
options: Optional. It is used when creating a PJRT plugin client.
"""
def factory():
if xla_extension_version >= 152:
# Plugin may already be statically linked in some configurations.
if not xla_client.pjrt_plugin_loaded(plugin_name):
if library_path is None:
@ -419,13 +413,6 @@ def register_plugin(
' plugin.'
)
xla_client.load_pjrt_plugin_dynamically(plugin_name, library_path)
else:
if library_path is None:
raise ValueError(
'The library path is None when trying to dynamically load the'
' plugin.'
)
xla_client.load_pjrt_plugin_dynamically(plugin_name, library_path)
return xla_client.make_c_api_client(plugin_name, options)
logger.debug(
@ -810,11 +797,9 @@ def using_pjrt_c_api(backend=None):
# TODO(parkers): Get rid of this in favor of a generic way to get topologies.
def make_pjrt_tpu_topology(topology_name=None, **kwargs):
if xla_extension_version >= 153:
# TODO(b/261484192): Make a system for lazily loading libtpu.so and call
# that inside make_tfrt_tpu_c_api_device_topology.
get_backend() # Properly initialize libtpu.so.
return xla_client.make_tfrt_tpu_c_api_device_topology(
topology_name, **kwargs
)
raise NotImplementedError('make_pjrt_tpu_topology is not implemented')

View File

@ -27,7 +27,6 @@ from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src import custom_api_util
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.api_util import flatten_fun_nokwargs
from jax._src.api_util import argnums_partial
@ -496,8 +495,7 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
result_types = [mlir.aval_to_ir_type(s) for s in call.out_avals]
out = hlo.CustomCallOp(
(result_types if xla_extension_version >= 150 or len(result_types) == 1
else [ir.TupleType.get(result_types)]),
result_types,
list(values),
call_target_name=ir.StringAttr.get(_CUSTOM_PARTITIONING_CALL_NAME),
has_side_effect=ir.BoolAttr.get(False),
@ -506,13 +504,7 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
backend_config=ir.StringAttr.get(key),
operand_layouts=None,
result_layouts=None)
if xla_extension_version >= 150 or len(result_types) == 1:
return out.results
else:
return [
hlo.GetTupleElementOp(out, mlir.i32_attr(i)).result
for i in range(len(result_types))
]
mlir.register_lowering(custom_partitioning_p,
_custom_partitioning_lowering_rule)

View File

@ -16,7 +16,7 @@
# eval()-ed by setup.py, so it should not have any dependencies.
__version__ = "0.4.12"
_minimum_jaxlib_version = "0.4.7"
_minimum_jaxlib_version = "0.4.11"
def _version_as_tuple(version_str):
return tuple(int(i) for i in version_str.split(".") if i.isdigit())

View File

@ -26,7 +26,6 @@ from jax._src.config import flags
from jax.experimental.pjit import pjit
from jax.experimental.serialize_executable import (
serialize, deserialize_and_load)
from jax._src.lib import xla_extension_version
from jax.experimental import topologies
from jax.sharding import PartitionSpec as P
@ -40,9 +39,6 @@ with contextlib.suppress(ImportError):
pytestmark = pytest.mark.multiaccelerator
@unittest.skipIf(
xla_extension_version < 151, 'Test requires xla_extension_version >= 151'
)
class JaxAotTest(jtu.JaxTestCase):
@jtu.skip_on_devices('cpu', 'gpu')

View File

@ -68,7 +68,6 @@ from jax._src import prng
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
from jax._src import test_util as jtu
from jax import tree_util
from jax._src import linear_util as lu
@ -1177,8 +1176,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
self.assertIn("jax.result_info = \"['a']\"", mhlo_str)
self.assertIn("jax.result_info = \"['b'][0][0]\"", mhlo_str)
@unittest.skipIf(xla_extension_version < 145,
'Test requires xla_extension_version >= 145')
def test_jit_lower_compile_with_compiler_options(self):
def f(x):
return jnp.sqrt(x ** 2) + 1.
@ -1188,8 +1185,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
lowered.compile( # doesn't crash
compiler_options={"xla_embed_ir_in_executable": True})
@unittest.skipIf(xla_extension_version < 145,
'Test requires xla_extension_version >= 145')
def test_jit_lower_compile_with_compiler_options_invalid(self):
def f(x):
return jnp.sqrt(x ** 2) + 1.
@ -1207,8 +1202,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
lambda: lowered.compile(
compiler_options={"xla_embed_ir_in_executable": "invalid_value"}))
@unittest.skipIf(xla_extension_version < 145,
'Test requires xla_extension_version >= 145')
def test_jit_lower_compile_with_compiler_options_multiple(self):
def f(x):
return jnp.sqrt(x ** 2) + 1.
@ -4231,8 +4224,6 @@ class APITest(jtu.JaxTestCase):
out = jax.grad(f)(3.0) # doesn't crash
self.assertAllClose(out, 1., check_dtypes=False)
@unittest.skipIf(xla_extension_version < 146,
'Test requires xla_extension_version >= 146')
def test_cache_clear_pmap(self):
@jax.pmap
def f(i):

View File

@ -30,7 +30,6 @@ from jax._src import op_shardings
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.util import safe_zip
from jax._src.sharding_impls import _from_op_sharding_to_pos_sharding
from jax.experimental.pjit import pjit
@ -694,7 +693,6 @@ class JaxArrayTest(jtu.JaxTestCase):
dtype=jtu.dtypes.all,
shape=[(), (10), (2, 3)],
)
@unittest.skipIf(xla_extension_version < 158, "Test requires jaxlib >= 0.4.11")
def test_buffer_protocol(self, dtype, shape):
if jtu.device_under_test() != "cpu":
raise unittest.SkipTest("Buffer protocol only works on CPU")
@ -716,7 +714,6 @@ class JaxArrayTest(jtu.JaxTestCase):
y_bytes = memoryview(y).tobytes()
self.assertEqual(x_bytes, y_bytes)
@unittest.skipIf(xla_extension_version < 157, "Test requires jaxlib >= 0.4.11")
def test_buffer_protocol_deletion(self):
if jtu.device_under_test() != "cpu":
raise unittest.SkipTest("Buffer protocol only works on CPU")
@ -742,8 +739,6 @@ class JaxArrayTest(jtu.JaxTestCase):
self.assertArraysEqual(np.arange(8.), x)
def test_array_fully_replicated_shard(self):
if xla_extension_version < 148:
self.skipTest('Requires xla_extension_version >= 148')
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
inp_shape = (8, 2)

View File

@ -28,7 +28,6 @@ from jax._src import debugging
from jax._src import dispatch
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.lib import xla_extension_version
import jax.numpy as jnp
import numpy as np
@ -1150,8 +1149,6 @@ class InspectShardingTest(jtu.JaxTestCase):
if jtu.is_cloud_tpu():
raise unittest.SkipTest("Inspect sharding is not supported on libtpu.")
if xla_bridge.using_pjrt_c_api() and xla_extension_version < 150:
raise unittest.SkipTest("Inspect sharding is not supported on Cloud TPU")
is_called = False
def _cb(sd):

View File

@ -56,7 +56,6 @@ from jax.interpreters import mlir
from jax._src import xla_bridge
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
from jax._src.util import curry, unzip2, safe_zip
from jax import config
@ -1164,9 +1163,6 @@ class CustomPartitionerTest(jtu.JaxTestCase):
def test_custom_partitioner(self):
self.skip_if_custom_partitioning_not_supported()
if xla_extension_version < 154:
self.skipTest('Requires xla_extension_version >= 154')
def partition(precision, arg_shapes, result_shape):
arg_shardings = jax.tree_map(lambda s: s.sharding, arg_shapes)
result_sharding = result_shape[0].sharding
@ -1285,9 +1281,6 @@ class CustomPartitionerTest(jtu.JaxTestCase):
@jtu.with_mesh([('x', 4), ('y', 2)])
def test_custom_partitioner_invalid_sharding(self):
self.skip_if_custom_partitioning_not_supported()
if xla_extension_version < 149:
self.skipTest('Requires xla_extension_version >= 149')
def partition(arg_shapes, result_shape):
def lower_fn(x):
return x
@ -3888,7 +3881,6 @@ class UtilTest(jtu.JaxTestCase):
hs2 = xc.HloSharding.from_proto(op2)
hs3 = xc.HloSharding.from_proto(op3)
if xla_extension_version >= 156:
self.assertEqual(hs1, xc.HloSharding.iota_tile((2, 2)))
self.assertEqual(hs2, xc.HloSharding.iota_tile((2, 2)))
self.assertEqual(hs3, xc.HloSharding.iota_tile((4, 2)))
@ -3919,7 +3911,6 @@ class UtilTest(jtu.JaxTestCase):
hs1 = xc.HloSharding.from_proto(op1)
hs2 = xc.HloSharding.from_proto(op2)
if xla_extension_version >= 156:
self.assertEqual(
hs1,
xc.HloSharding.iota_tile(
@ -3972,8 +3963,6 @@ class UtilTest(jtu.JaxTestCase):
self.assertNotEqual(hash(hs1), hash(hs2))
def test_hlo_sharding_iota_tile_error(self):
if xla_extension_version < 156:
self.skipTest('Requires xla_extension_version >= 156')
self.assertRaisesRegex(
xla_extension.XlaRuntimeError,
'INVALID_ARGUMENT: `dims` should not be empty.',
@ -4091,7 +4080,6 @@ class UtilTest(jtu.JaxTestCase):
self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2))
self.assertTrue(op_shardings.are_op_shardings_equal(op1, op3))
if xla_extension_version >= 156:
hs1 = xc.HloSharding.from_proto(op1)
self.assertEqual(
hs1,
@ -4124,9 +4112,6 @@ class UtilTest(jtu.JaxTestCase):
)
def test_hlo_sharding_manual_replicated(self):
if xla_extension_version < 156:
self.skipTest('Requires xla_extension_version >= 156')
hs1 = xc.HloSharding.manual()
self.assertTrue(hs1.is_manual())
self.assertFalse(hs1.tile_assignment_devices())

View File

@ -47,7 +47,6 @@ from jax._src import sharding_impls
from jax._src import sharding_specs
from jax._src import xla_bridge
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
from jax._src.util import safe_map, safe_zip
from jax._src.interpreters import pxla
from jax.interpreters import xla
@ -332,8 +331,6 @@ class PythonPmapTest(jtu.JaxTestCase):
f = f.lower(x).compile()
self.assertIsNotNone(f.runtime_executable())
@unittest.skipIf(xla_extension_version < 145,
'Test requires xla_extension_version >= 145')
def test_jit_lower_compile_with_compiler_options(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
@ -343,8 +340,6 @@ class PythonPmapTest(jtu.JaxTestCase):
lowered.compile( # doesn't crash
compiler_options={"xla_embed_ir_in_executable": True})
@unittest.skipIf(xla_extension_version < 145,
'Test requires xla_extension_version >= 145')
def test_jit_lower_compile_with_compiler_options_invalid(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
@ -361,8 +356,6 @@ class PythonPmapTest(jtu.JaxTestCase):
lambda: lowered.compile(
compiler_options={"xla_embed_ir_in_executable": "invalid_value"}))
@unittest.skipIf(xla_extension_version < 145,
'Test requires xla_extension_version >= 145')
def test_jit_lower_compile_with_compiler_options_multiple(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)

View File

@ -21,7 +21,6 @@ from absl.testing import absltest
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax.interpreters import xla
from jax._src.config import config
@ -97,15 +96,9 @@ class XlaBridgeTest(jtu.JaxTestCase):
xb.register_pjrt_plugin_factories_from_env()
client_factory, priotiy = xb._backend_factories["name1"]
with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make:
with mock.patch.object(xc, "load_pjrt_plugin_dynamically", autospec=True):
with mock.patch.object(
xc, "load_pjrt_plugin_dynamically", autospec=True
) as mock_load_plugin:
if xla_extension_version >= 152:
with mock.patch.object(
xc, "pjrt_plugin_loaded", autospec=True
) as mock_plugin_loaded:
client_factory()
else:
xc, "pjrt_plugin_loaded", autospec=True) as mock_plugin_loaded:
client_factory()
self.assertRegex(
@ -116,10 +109,7 @@ class XlaBridgeTest(jtu.JaxTestCase):
self.assertIn("name1", xb._backend_factories)
self.assertIn("name2", xb._backend_factories)
self.assertEqual(priotiy, 400)
if xla_extension_version >= 152:
mock_plugin_loaded.assert_called_once_with("name1")
else:
mock_load_plugin.assert_called_once_with("name1", "path1")
mock_make.assert_called_once_with("name1", None)
def test_register_plugin_with_config(self):
@ -130,25 +120,14 @@ class XlaBridgeTest(jtu.JaxTestCase):
xb.register_pjrt_plugin_factories_from_env()
client_factory, priority = xb._backend_factories["name1"]
with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make:
with mock.patch.object(xc, "load_pjrt_plugin_dynamically", autospec=True):
with mock.patch.object(
xc, "load_pjrt_plugin_dynamically", autospec=True
) as mock_load_plugin:
if xla_extension_version >= 152:
with mock.patch.object(
xc, "pjrt_plugin_loaded", autospec=True
) as mock_plugin_loaded:
client_factory()
else:
xc, "pjrt_plugin_loaded", autospec=True) as mock_plugin_loaded:
client_factory()
self.assertIn("name1", xb._backend_factories)
self.assertEqual(priority, 400)
if xla_extension_version >= 152:
mock_plugin_loaded.assert_called_once_with("name1")
else:
mock_load_plugin.assert_called_once_with(
"name1", "/path/pjrt_plugin_name1.so"
)
mock_make.assert_called_once_with(
"name1",
{