mirror of
https://github.com/ROCm/jax.git
synced 2025-04-20 13:56:07 +00:00
Bump minimum jaxlib version to v0.4.30.
This corresponds to xla_extension_version 271 and mlir_api_version 57.
This commit is contained in:
parent
7827af9407
commit
07d24e7dcc
@ -40,7 +40,6 @@ from jax._src.interpreters import xla
|
||||
from jax._src.layout import AutoLayout, DeviceLocalLayout, Layout
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension as xe
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.sharding import Sharding
|
||||
from jax._src.sharding_impls import (
|
||||
PmapSharding, SingleDeviceSharding,
|
||||
@ -1120,14 +1119,8 @@ def _array_shard_arg(xs, shardings):
|
||||
results.append(
|
||||
shard_sharded_device_array_slow_path(x, devices, indices, sharding))
|
||||
|
||||
if xla_extension_version < 271:
|
||||
copy_outs = [
|
||||
xc.copy_array_to_devices_with_sharding(x, d, s) # pytype: disable=module-attr
|
||||
for x, d, s in safe_zip(batch_xs, batch_devs, batch_shardings)
|
||||
]
|
||||
else:
|
||||
copy_outs = xc.batched_copy_array_to_devices_with_sharding(
|
||||
batch_xs, batch_devs, batch_shardings)
|
||||
copy_outs = xc.batched_copy_array_to_devices_with_sharding(
|
||||
batch_xs, batch_devs, batch_shardings)
|
||||
for i, copy_out in safe_zip(batch_indices, copy_outs):
|
||||
assert results[i] is None
|
||||
results[i] = copy_out
|
||||
|
@ -62,7 +62,6 @@ from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
@ -3175,10 +3174,7 @@ class MeshExecutable(stages.XlaExecutable):
|
||||
self.unsafe_call.in_handler.input_indices)
|
||||
else:
|
||||
fastpath_data = None
|
||||
if xla_extension_version > 267:
|
||||
return outs, fastpath_data, False # Do not remove cache entry
|
||||
else:
|
||||
return outs, fastpath_data
|
||||
return outs, fastpath_data, False # Do not remove cache entry
|
||||
|
||||
return xc._xla.pjit(
|
||||
self.unsafe_call.name, None, aot_cache_miss, [], [], [],
|
||||
|
@ -335,11 +335,6 @@ def _approx_top_k_lowering(ctx, operand, *, k,
|
||||
backend_config=backend_config,
|
||||
result_shapes=result_shapes)
|
||||
else:
|
||||
if xc.mlir_api_version < 57:
|
||||
raise NotImplementedError(
|
||||
"approx_top_k with non-constant k requires jaxlib version 0.4.29 or "
|
||||
"newer")
|
||||
|
||||
k_value, = mlir.eval_dynamic_shape_as_vals(ctx, (k,))
|
||||
out = mlir.custom_call(
|
||||
"stablehlo.dynamic_approx_top_k",
|
||||
|
@ -50,7 +50,6 @@ from jax._src.lax.control_flow.common import (
|
||||
_abstractify, _avals_short, _check_tree_and_avals, _initial_style_jaxpr,
|
||||
_initial_style_jaxpr_attrs, _make_closed_jaxpr_attrs, _prune_zeros,
|
||||
_typecheck_param)
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.numpy.ufuncs import logaddexp
|
||||
@ -2370,27 +2369,12 @@ def _cumulative_reduction_primitive(name, reduce_fn, reduce_window_fn):
|
||||
mlir.cache_lowering(mlir.lower_fun(fn, multiple_results=False)),
|
||||
platform=platform)
|
||||
|
||||
if xla_extension_version >= 266:
|
||||
# For jax-metal, until reduce_window legalization is better supported.
|
||||
register_lowering(partial(associative_scan, reduce_fn), 'METAL')
|
||||
# In XLA, there's a rewriter for an O(N^2) reduce-window implementation.
|
||||
register_lowering(
|
||||
partial(cumred_reduce_window_impl, reduce_window_fn)
|
||||
)
|
||||
else:
|
||||
# Older XLA versions only have this rewrite for TPU.
|
||||
register_lowering(
|
||||
partial(cumred_reduce_window_impl, reduce_window_fn), 'tpu'
|
||||
)
|
||||
# Default for platforms not treated specially below.
|
||||
register_lowering(partial(associative_scan, reduce_fn))
|
||||
|
||||
# On GPU, we choose between window reduction and associative scan
|
||||
# based on the input size.
|
||||
for platform in ['cuda', 'rocm']:
|
||||
register_lowering(
|
||||
partial(cumred_gpu_impl, reduce_window_fn, reduce_fn), platform
|
||||
)
|
||||
# For jax-metal, until reduce_window legalization is better supported.
|
||||
register_lowering(partial(associative_scan, reduce_fn), 'METAL')
|
||||
# In XLA, there's a rewriter for an O(N^2) reduce-window implementation.
|
||||
register_lowering(
|
||||
partial(cumred_reduce_window_impl, reduce_window_fn)
|
||||
)
|
||||
|
||||
return reducer_p
|
||||
|
||||
|
@ -44,7 +44,6 @@ from jax._src.lib import gpu_linalg
|
||||
from jax._src.lib import gpu_solver
|
||||
from jax._src.lib import gpu_sparse
|
||||
from jax._src.lib import lapack
|
||||
from jax._src.lib import version as jaxlib_version
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import chlo
|
||||
@ -510,10 +509,6 @@ def _cholesky_update_cuda_lowering_rule(ctx, r_matrix, w_vector):
|
||||
raise NotImplementedError(
|
||||
"Can only lower fast cholesky_update on CUDA."
|
||||
)
|
||||
if jaxlib_version < (0, 4, 29):
|
||||
raise NotImplementedError(
|
||||
f"The jaxlib version {jaxlib_version} is too old."
|
||||
"Please update to at least 0.4.29.")
|
||||
return gpu_linalg.cuda_cholesky_update(
|
||||
r_matrix, w_vector, r_matrix_aval.dtype)
|
||||
|
||||
|
@ -60,13 +60,12 @@ from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import func as func_dialect
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src import sharding
|
||||
from jax._src.sharding_impls import (
|
||||
NamedSharding, XLACompatibleSharding, GSPMDSharding,
|
||||
NamedSharding, GSPMDSharding,
|
||||
SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue,
|
||||
ParsedPartitionSpec, SpecSync, get_single_pspec, is_auto, is_unspecified,
|
||||
is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding)
|
||||
@ -333,10 +332,7 @@ def _cpp_pjit(jit_info: PjitInfo):
|
||||
jaxpr.consts, jit_info.abstracted_axes,
|
||||
pgle_profiler)
|
||||
|
||||
if xla_extension_version > 267:
|
||||
return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)
|
||||
else:
|
||||
return outs, maybe_fastpath_data
|
||||
return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)
|
||||
|
||||
fun = jit_info.fun
|
||||
cpp_pjit_f = xc._xla.pjit(
|
||||
@ -1403,10 +1399,6 @@ def _resolve_in_shardings(
|
||||
# not allow None as the sharding.
|
||||
if arg_s is None:
|
||||
continue
|
||||
if xla_extension_version < 270:
|
||||
if not isinstance(arg_s, XLACompatibleSharding):
|
||||
raise ValueError(f'One of the argument to pjit got sharding {arg_s} '
|
||||
'which is not a subclass of XLACompatibleSharding.')
|
||||
# Don't consider PmapSharding inputs as committed. They will get resharded
|
||||
# unconditionally.
|
||||
if isinstance(arg_s, PmapSharding):
|
||||
@ -1621,10 +1613,7 @@ def _pjit_call_impl(*args, jaxpr,
|
||||
fastpath_data = _get_fastpath_data(
|
||||
compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects,
|
||||
jaxpr.consts, None, pgle_profiler)
|
||||
if xla_extension_version > 267:
|
||||
return out_flat, fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)
|
||||
else:
|
||||
return out_flat, fastpath_data
|
||||
return out_flat, fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)
|
||||
|
||||
f = _get_jaxpr_as_fun(
|
||||
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
|
||||
|
@ -32,7 +32,6 @@ from jax._src import util
|
||||
from jax._src import xla_bridge
|
||||
from jax._src import core
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.op_shardings import (
|
||||
are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated)
|
||||
from jax._src.partition_spec import PartitionSpec
|
||||
@ -1054,10 +1053,6 @@ def prepare_axis_resources(axis_resources,
|
||||
if isinstance(entry, PmapSharding):
|
||||
raise ValueError(f'One of {what} got sharding {entry} which is not '
|
||||
'allowed.')
|
||||
if xla_extension_version < 270:
|
||||
if not isinstance(entry, XLACompatibleSharding):
|
||||
raise ValueError(f'One of {what} got sharding {entry} which is not a '
|
||||
'subclass of XLACompatibleSharding.')
|
||||
new_entries.append(entry)
|
||||
else:
|
||||
new_entries.append(ParsedPartitionSpec.from_user_input(
|
||||
|
@ -50,7 +50,6 @@ from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact
|
||||
from jax._src.public_test_util import ( # noqa: F401
|
||||
_assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads,
|
||||
@ -245,32 +244,18 @@ def count_primitive_compiles():
|
||||
|
||||
@contextmanager
|
||||
def count_device_put_fast_path_hit():
|
||||
if xla_extension_version < 271:
|
||||
original_fn = xc.copy_array_to_devices_with_sharding
|
||||
count = [0]
|
||||
original_fn = xc.batched_copy_array_to_devices_with_sharding
|
||||
count = [0]
|
||||
|
||||
def copy_array_to_devices_with_sharding_and_count(*args, **kwargs):
|
||||
count[0] += 1
|
||||
return original_fn(*args, **kwargs)
|
||||
def batched_copy_array_to_devices_with_sharding_and_count(*args, **kwargs):
|
||||
count[0] += 1
|
||||
return original_fn(*args, **kwargs)
|
||||
|
||||
xc.copy_array_to_devices_with_sharding = copy_array_to_devices_with_sharding_and_count
|
||||
try:
|
||||
yield count
|
||||
finally:
|
||||
xc.copy_array_to_devices_with_sharding = original_fn
|
||||
else:
|
||||
original_fn = xc.batched_copy_array_to_devices_with_sharding
|
||||
count = [0]
|
||||
|
||||
def batched_copy_array_to_devices_with_sharding_and_count(*args, **kwargs):
|
||||
count[0] += 1
|
||||
return original_fn(*args, **kwargs)
|
||||
|
||||
xc.batched_copy_array_to_devices_with_sharding = batched_copy_array_to_devices_with_sharding_and_count
|
||||
try:
|
||||
yield count
|
||||
finally:
|
||||
xc.batched_copy_array_to_devices_with_sharding = original_fn
|
||||
xc.batched_copy_array_to_devices_with_sharding = batched_copy_array_to_devices_with_sharding_and_count
|
||||
try:
|
||||
yield count
|
||||
finally:
|
||||
xc.batched_copy_array_to_devices_with_sharding = original_fn
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
@ -45,7 +45,6 @@ from jax._src.cloud_tpu_init import get_tpu_library_path
|
||||
from jax._src.lib import cuda_versions
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src.lib import xla_extension_version
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -147,13 +146,9 @@ def tpu_client_timer_callback(timer_secs: float) -> xla_client.Client | None:
|
||||
t.start()
|
||||
|
||||
try:
|
||||
if xla_extension_version >= 267:
|
||||
client = xla_client.make_tpu_client( # type: ignore
|
||||
get_tpu_library_path(),
|
||||
_options_from_jax_configs("tpu"))
|
||||
else:
|
||||
client = xla_client.make_tpu_client(
|
||||
get_tpu_library_path())
|
||||
client = xla_client.make_tpu_client( # type: ignore
|
||||
get_tpu_library_path(),
|
||||
_options_from_jax_configs("tpu"))
|
||||
finally:
|
||||
t.cancel()
|
||||
|
||||
|
@ -133,7 +133,7 @@ def _get_cmdclass(pkg_source_path):
|
||||
|
||||
|
||||
__version__ = _get_version_string()
|
||||
_minimum_jaxlib_version = "0.4.27"
|
||||
_minimum_jaxlib_version = "0.4.30"
|
||||
|
||||
def _version_as_tuple(version_str):
|
||||
return tuple(int(i) for i in version_str.split(".") if i.isdigit())
|
||||
|
@ -58,7 +58,6 @@ from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src.lib import xla_extension_version
|
||||
import jax._src.util as jax_util
|
||||
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
|
||||
import jax.custom_batching
|
||||
@ -707,7 +706,6 @@ class JitTest(jtu.BufferDonationTestCase):
|
||||
self.assertNotEqual(z3.unsafe_buffer_pointer(), x1.unsafe_buffer_pointer())
|
||||
self.assertEqual(z2, 1)
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 264, "jaxlib version too old")
|
||||
def test_print_token_buffer_error(self):
|
||||
token = jax.lax.create_token()
|
||||
with self.assertRaisesRegex(
|
||||
|
@ -30,7 +30,6 @@ from jax._src import op_shardings
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.util import safe_zip
|
||||
from jax._src.sharding import common_devices_indices_map
|
||||
from jax._src.sharding_impls import (_op_sharding_to_pos_sharding,
|
||||
@ -1244,8 +1243,6 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(ValueError, msg):
|
||||
jax.jit(f)(x)
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 269,
|
||||
"Test requires jaxlib 0.4.29 or newer")
|
||||
def test_make_array_from_single_device_arrays_bad_inputs(self):
|
||||
x = jnp.arange(10)
|
||||
mesh = jtu.create_global_mesh((2,), ('x',))
|
||||
|
@ -19,16 +19,11 @@ from jax import numpy as jnp
|
||||
from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lax import linalg as lax_linalg
|
||||
from jax._src.lib import version as jaxlib_version # pylint: disable=g-importing-member
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
class CholeskyUpdateTest(jtu.JaxTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
if jaxlib_version < (0, 4, 29):
|
||||
self.skipTest("Requires jaxlib 0.4.29 or newer")
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=[
|
||||
|
@ -62,7 +62,6 @@ from jax.sharding import PartitionSpec as P
|
||||
from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import cuda_versions
|
||||
from jax._src.lib import xla_client
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
@ -713,8 +712,6 @@ class CompatTest(bctu.CompatTestBase):
|
||||
check_results=check_top_k_results)
|
||||
|
||||
def test_dynamic_approx_top_k(self):
|
||||
if xla_client.mlir_api_version < 57:
|
||||
self.skipTest("Requires newer jaxlib")
|
||||
# stablehlo.dynamic_approx_top_k is used temporarily for a approx_top_k
|
||||
# with dynamism
|
||||
# This is the input that was used to generate the test_data
|
||||
|
@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from absl.testing import absltest, parameterized
|
||||
@ -31,7 +30,6 @@ from jax._src import test_util as jtu
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.extend import ffi
|
||||
from jax._src.lib import xla_extension_version
|
||||
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
@ -93,7 +91,6 @@ class RandomTest(jtu.JaxTestCase):
|
||||
|
||||
class FfiTest(jtu.JaxTestCase):
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 265, "Requires jaxlib 0.4.29")
|
||||
def testHeadersExist(self):
|
||||
base_dir = os.path.join(jex.ffi.include_dir(), "xla", "ffi", "api")
|
||||
for header in ["c_api.h", "api.h", "ffi.h"]:
|
||||
|
@ -26,7 +26,6 @@ from jax import lax
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src import config
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
|
||||
import jax.numpy as jnp
|
||||
from jax.ad_checkpoint import Offloadable, remat, Recompute
|
||||
@ -68,8 +67,6 @@ def _create_inputs(shape, pspec, mem_kind=None):
|
||||
class ShardingMemoriesTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
if xla_extension_version < 265 and not jtu.test_device_matches(["tpu"]):
|
||||
self.skipTest("Memories do not work on CPU and GPU backends yet.")
|
||||
# TODO(b/311021572)
|
||||
if jtu.is_cloud_tpu():
|
||||
self.skipTest("Experimental feature not yet implemented on Cloud TPU")
|
||||
|
@ -32,7 +32,6 @@ from jax.experimental import profiler as exp_profiler
|
||||
import jax.numpy as jnp
|
||||
from jax.sharding import NamedSharding, PartitionSpec
|
||||
from jax._src import compilation_cache as cc
|
||||
from jax._src.lib import xla_extension_version
|
||||
import numpy as np
|
||||
|
||||
from jax.experimental.serialize_executable import (
|
||||
@ -55,9 +54,6 @@ class PgleTest(jtu.JaxTestCase):
|
||||
|
||||
@unittest.skip("Test failing in CI")
|
||||
def testPGLEProfilerGetFDOProfile(self):
|
||||
if xla_extension_version < 268:
|
||||
return self.skipTest('Requires xla_extension_version >= 268')
|
||||
|
||||
mesh = jtu.create_global_mesh((2,), ('x',))
|
||||
|
||||
@partial(
|
||||
@ -86,9 +82,6 @@ class PgleTest(jtu.JaxTestCase):
|
||||
self.assertIn(b'custom', fdo_profile)
|
||||
|
||||
def testAutoPgle(self):
|
||||
if xla_extension_version < 268:
|
||||
return self.skipTest('Requires xla_extension_version >= 268')
|
||||
|
||||
mesh = jtu.create_global_mesh((2,), ('x',))
|
||||
|
||||
@partial(
|
||||
@ -126,9 +119,6 @@ class PgleTest(jtu.JaxTestCase):
|
||||
self.assertEqual(cache_miss_count[0], 0)
|
||||
|
||||
def testAutoPgleWithAot(self):
|
||||
if xla_extension_version < 268:
|
||||
return self.skipTest('Requires xla_extension_version >= 268')
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
return x * 2
|
||||
@ -153,8 +143,6 @@ class PgleTest(jtu.JaxTestCase):
|
||||
|
||||
@unittest.skip("Test failing in CI")
|
||||
def testAutoPgleWithPersistentCache(self):
|
||||
if xla_extension_version < 268:
|
||||
return self.skipTest('Requires xla_extension_version >= 268')
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
|
@ -2014,8 +2014,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
expect_error=(
|
||||
(NotImplementedError, "aggregate_to_topk=False") if (
|
||||
not agg and (isinstance(k, str) or
|
||||
isinstance(n, str))) else
|
||||
(NotImplementedError, "requires jaxlib version") if xla_client.mlir_api_version < 57 and agg and isinstance(k, str) else
|
||||
isinstance(n, str))) else
|
||||
None
|
||||
))
|
||||
for n in [8, "n"]
|
||||
|
@ -42,7 +42,6 @@ from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import tree_util
|
||||
from jax._src.lib import xla_extension_version
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jax.experimental.custom_partitioning import custom_partitioning
|
||||
@ -1706,7 +1705,6 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(ValueError, "in_specs refers to 'j'"):
|
||||
f(v)
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 262, "Requires jaxlib 0.4.28")
|
||||
def test_nested_partial_auto(self):
|
||||
mesh = jtu.create_global_mesh((2, 2), ('i', 'j'))
|
||||
|
||||
|
@ -27,7 +27,6 @@ from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
@ -162,14 +161,9 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
def _mock_tpu_client(library_path=None):
|
||||
_mock_tpu_client_with_options(library_path=library_path, options=None)
|
||||
|
||||
if xla_extension_version >= 267:
|
||||
with mock.patch.object(xc, "make_tpu_client",
|
||||
side_effect=_mock_tpu_client_with_options):
|
||||
xb.tpu_client_timer_callback(0.01)
|
||||
else:
|
||||
with mock.patch.object(xc, "make_tpu_client",
|
||||
side_effect=_mock_tpu_client):
|
||||
xb.tpu_client_timer_callback(0.01)
|
||||
with mock.patch.object(xc, "make_tpu_client",
|
||||
side_effect=_mock_tpu_client_with_options):
|
||||
xb.tpu_client_timer_callback(0.01)
|
||||
|
||||
def test_register_plugin(self):
|
||||
with self.assertLogs(level="WARNING") as log_output:
|
||||
@ -205,7 +199,7 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
self.assertTrue(registration.experimental)
|
||||
|
||||
options = {}
|
||||
if xb.get_backend().platform == 'tpu' and xla_extension_version >= 267:
|
||||
if xb.get_backend().platform == 'tpu':
|
||||
options["ml_framework_name"] = "JAX"
|
||||
options["ml_framework_version"] = version.__version__
|
||||
mock_make.assert_called_once_with("name1", options, None)
|
||||
@ -243,7 +237,7 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
"string_option": "string",
|
||||
"float_option": 1.0,
|
||||
}
|
||||
if xb.get_backend().platform == 'tpu' and xla_extension_version >= 267:
|
||||
if xb.get_backend().platform == 'tpu':
|
||||
options["ml_framework_name"] = "JAX"
|
||||
options["ml_framework_version"] = version.__version__
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user