Increase minimum jaxlib version to 0.1.74.

This commit is contained in:
Peter Hawkins 2021-11-18 14:55:19 -05:00
parent 52421da0f6
commit 3fd3c46f20
15 changed files with 24 additions and 155 deletions

View File

@ -109,8 +109,7 @@ flags.DEFINE_bool(
"Set this to `False` only if it crashes otherwise and report "
"the error to the jax-team.")
flags.DEFINE_bool(
"experimental_cpp_pmap",
bool_env("JAX_CPP_PMAP", jax._src.lib._xla_extension_version >= 39),
"experimental_cpp_pmap", bool_env("JAX_CPP_PMAP", True),
"A flag enabling the C++ jax.pmap fast path. Until the default "
"is switched to True, the feature is not supported and possibly broken "
"(e.g. it may use unreleased code from jaxlib.")
@ -2014,8 +2013,7 @@ def _cpp_pmap(
return out, fastpath_data
# TODO(slebedev): Remove the ignore once jaxlib>=0.1.71.
cpp_mapped_f = pmap_lib.pmap(fun, cache_miss, # type: ignore[call-arg]
cpp_mapped_f = pmap_lib.pmap(fun, cache_miss,
static_broadcasted_tuple, pxla._shard_arg)
f_pmapped = wraps(fun)(cpp_mapped_f)

View File

@ -55,7 +55,6 @@ import jax._src.lib
from jax._src.lib import pytree
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import version as jaxlib_version
xb = xla_bridge
xc = xla_client
@ -2801,21 +2800,7 @@ ad.defjvp2(rsqrt_p,
lambda g, ans, x:
mul(g, mul(_const(x, -0.5), div(ans, x))))
# TODO(phawkins): remove the fallback translation rule after the minimum jaxlib
# is 0.1.70 or newer.
if jax._src.lib._xla_extension_version >= 28:
_cbrt_translation_rule = None
else:
def _cbrt_translation_rule(ctx, avals_in, avals_out, x):
x_aval, = avals_in
return [xops.Mul(
xops.Sign(x),
xops.Pow(xops.Abs(x),
xla.pyval_to_ir_constant(ctx.builder,
np.array(1/3, dtype=x_aval.dtype))))]
cbrt_p = standard_unop(_float, 'cbrt',
translation_rule=_cbrt_translation_rule)
cbrt_p = standard_unop(_float, 'cbrt')
ad.defjvp2(cbrt_p,
lambda g, ans, x: mul(g, mul(_const(x, 1/3), integer_pow(ans, -2))))
@ -5900,15 +5885,6 @@ def _compute_argminmax(value_comparator, get_identity,
axes)
return res[1]
def _argminmax_gpu_translation_rule(op, a, *, axes, index_dtype):
  """GPU lowering for argmin/argmax built from a reduction plus an index mask.

  Args:
    op: reduction callable invoked as ``op(a, (axis,))``.
    a: operand to search.
    axes: one-element sequence holding the axis to reduce over.
    index_dtype: integer dtype used for the returned indices.

  Returns:
    The minimum, along the reduced axis, of the positions where ``a`` equals
    the reduced value or compares unequal to itself (e.g. NaN).
  """
  (reduce_axis,) = axes
  reduce_dims = (reduce_axis,)
  positions = tie_in(a, broadcasted_iota(index_dtype, a.shape, reduce_axis))
  # Largest value of index_dtype, used as a filler for non-matching positions
  # so that they lose the final min-reduction.
  filler = np.array(dtypes.iinfo(index_dtype).max, dtype=index_dtype)
  filler = broadcast(tie_in(a, filler), a.shape)
  extremum = expand_dims(op(a, reduce_dims), reduce_dims)
  is_candidate = eq(a, extremum) | ne(a, a)
  candidate_positions = select(is_candidate, positions, filler)
  return _reduce_min(candidate_positions, reduce_dims)
_argmin_translation_rule = xla.lower_fun(
partial(_compute_argminmax, lt, _get_min_identity),
multiple_results=False, new_style=True)
@ -5922,28 +5898,12 @@ argmin_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule,
weak_type_rule=_strip_weak_type)
batching.defreducer(argmin_p)
ad.defjvp_zero(argmin_p)
if jax._src.lib._xla_extension_version < 41:
xla.register_translation(
argmin_p,
xla.lower_fun(
partial(_argminmax_gpu_translation_rule, _reduce_min),
multiple_results=False,
new_style=True),
platform='gpu')
argmax_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule,
'argmax', _argmax_translation_rule,
weak_type_rule=_strip_weak_type)
batching.defreducer(argmax_p)
ad.defjvp_zero(argmax_p)
if jax._src.lib._xla_extension_version < 41:
xla.register_translation(
argmax_p,
xla.lower_fun(
partial(_argminmax_gpu_translation_rule, _reduce_max),
multiple_results=False,
new_style=True),
platform='gpu')
def _reduce_logical_shape_rule(operand, *, axes):
@ -6927,7 +6887,6 @@ def _rng_bit_generator_weak_type_rule(key, *, shape, dtype, algorithm):
def _rng_bit_generator_translation_rule(
ctx, avals_in, avals_out, key, *, shape, dtype, algorithm):
c = ctx.builder
backend_is_gpu = ctx.platform == "gpu"
key_shape, key_dtype = c.get_shape(key).dimensions(), c.get_shape(key).numpy_dtype()
# While the RngBitGenerator HLO accepts a u64[2] key on all backends, we
# typically represent the key argument to this primitive as a u32[4] so as to
@ -6938,48 +6897,15 @@ def _rng_bit_generator_translation_rule(
(key_shape == (2,) and key_dtype == np.dtype('uint64'))), (key_shape, key_dtype)
xla_shape = xc.Shape.array_shape(np.dtype(dtype), shape)
if key_dtype == np.dtype('uint32'):
# TODO(mattjj): the BitcastConvertType segfaults on GPU
# TODO(mattjj): remove fallback when minimum jaxlib is 0.1.72 or newer
if jaxlib_version >= (0, 1, 72) and not backend_is_gpu:
u64_etype = xla.dtype_to_primitive_type(np.dtype('uint64'))
key = xops.BitcastConvertType(xops.Reshape(key, (2, 2)), u64_etype)
else:
key = _convert_4xU32_to_2xU64_without_bitcast(c, key)
u64_etype = xla.dtype_to_primitive_type(np.dtype('uint64'))
key = xops.BitcastConvertType(xops.Reshape(key, (2, 2)), u64_etype)
out_key, out_vals = xla.xla_destructure(
c, xops.RngBitGenerator(algorithm, key, xla_shape))
if key_dtype == np.dtype('uint32'):
if jaxlib_version >= (0, 1, 72) and not backend_is_gpu:
u32_etype = xla.dtype_to_primitive_type(np.dtype('uint32'))
out_key = xops.Reshape(xops.BitcastConvertType(out_key, u32_etype), (4,))
else:
out_key = _convert_2xU64_to_4xU32_without_bitcast(c, out_key)
u32_etype = xla.dtype_to_primitive_type(np.dtype('uint32'))
out_key = xops.Reshape(xops.BitcastConvertType(out_key, u32_etype), (4,))
return [out_key, out_vals]
def _convert_4xU32_to_2xU64_without_bitcast(c, key):
  """Packs a uint32 key into a two-element uint64 key without a bitcast.

  For each pair ``(key[i], key[i + 1])`` with ``i`` in ``(0, 2)``, the first
  element is widened and shifted left by 32 bits, then XOR-ed with the
  widened second element to form one uint64 word.

  Args:
    c: builder passed to ``xops.Constant`` / ``xla.pyval_to_ir_constant``.
    key: XLA op for the key; elements 0..3 are read (assumes a u32[4] key,
      as the function name suggests).

  Returns:
    An XLA op holding the two assembled uint64 words.
  """
  u64 = xla.dtype_to_primitive_type(np.dtype('uint64'))
  packed = xops.Constant(c, np.zeros(2, dtype=np.dtype('uint64')))
  shift = xops.Constant(c, np.array(32, np.uint64))
  # word_idx is the destination slot (0, 1); start is the source offset (0, 2).
  for word_idx, start in enumerate((0, 2)):
    upper = xops.ConvertElementType(
        xops.Slice(key, [start], [start + 1], [1]), u64)
    lower = xops.ConvertElementType(
        xops.Slice(key, [start + 1], [start + 2], [1]), u64)
    word = xops.Xor(xops.ShiftLeft(upper, shift), lower)
    packed = xops.DynamicUpdateSlice(
        packed, word, [xla.pyval_to_ir_constant(c, word_idx)])
  return packed
def _convert_2xU64_to_4xU32_without_bitcast(c, key):
  """Unpacks a two-element uint64 key into four uint32 values without a bitcast.

  Each uint64 word ``key[i]`` produces two outputs: the word shifted right by
  32 bits and converted to uint32 goes to slot ``2 * i``; the word converted
  directly to uint32 goes to slot ``2 * i + 1`` (presumably keeping the low
  32 bits; confirm against XLA ConvertElementType semantics).

  Args:
    c: builder passed to ``xops.Constant`` / ``xla.pyval_to_ir_constant``.
    key: XLA op for the key; elements 0 and 1 are read.

  Returns:
    An XLA op holding the four uint32 values.
  """
  u32 = xla.dtype_to_primitive_type(np.dtype('uint32'))
  unpacked = xops.Constant(c, np.zeros(4, dtype=np.dtype('uint32')))
  shift = xops.Constant(c, np.array(32, np.uint64))
  for word_idx in range(2):
    word = xops.Slice(key, [word_idx], [word_idx + 1], [1])
    upper = xops.ConvertElementType(xops.ShiftRightLogical(word, shift), u32)
    lower = xops.ConvertElementType(word, u32)
    upper_slot = 2 * word_idx
    unpacked = xops.DynamicUpdateSlice(
        unpacked, upper, [xla.pyval_to_ir_constant(c, upper_slot)])
    unpacked = xops.DynamicUpdateSlice(
        unpacked, lower, [xla.pyval_to_ir_constant(c, upper_slot + 1)])
  return unpacked
def _rng_bit_generator_named_shape_rule(key, *, shape, dtype, algorithm):
return [key.named_shape, key.named_shape]

View File

@ -41,7 +41,6 @@ from jax._src.lib import cusparse
from jax._src.lib import rocsolver
from jax._src.lib import xla_client
from jax._src.lib import version as jaxlib_version
xops = xla_client.ops
@ -1538,11 +1537,6 @@ def _schur_cpu_translation_rule(ctx, avals_in, avals_out, operand, *,
batch_dims = operand_aval.shape[:-2]
c = ctx.builder
if jaxlib_version < (0, 1, 72):
raise NotImplementedError(
"The Schur primitive is only implemented for jaxlib versions >= 0.1.72"
)
_cpu_gees = lapack.gees
if sort_eig_vals:

View File

@ -120,10 +120,7 @@ def _hash_compile_options(hash_obj, compile_options_obj):
hash_obj.update(compile_options_obj.device_assignment.serialize())
def _hash_executable_build_options(hash_obj, executable_obj):
if jax._src.lib.version >= (0, 1, 72):
expected_options = 31
else:
expected_options = 30
expected_options = 31
assert len(dir(executable_obj)) == expected_options, (
f"Unexpected number of executable_build_options fields: "
f"{len(dir(executable_obj))}. This likely means that an extra "
@ -136,8 +133,7 @@ def _hash_executable_build_options(hash_obj, executable_obj):
if executable_obj.device_assignment is not None:
hash_obj.update(executable_obj.device_assignment.serialize())
_hash_bool(hash_obj, executable_obj.use_spmd_partitioning)
if jax._src.lib.version >= (0, 1, 72):
_hash_bool(hash_obj, executable_obj.allow_spmd_sharding_propagation_to_output)
_hash_bool(hash_obj, executable_obj.allow_spmd_sharding_propagation_to_output)
def _hash_debug_options(hash_obj, debug_obj):
_hash_bool(hash_obj, debug_obj.xla_cpu_enable_fast_math)

View File

@ -1674,20 +1674,13 @@ translations[lax.rng_uniform_p] = _rng_uniform_lowering
# xla_shape = xc.Shape.array_shape(np.dtype(dtype), shape)
# if key_dtype == np.dtype('uint32'):
# # TODO(mattjj): the BitcastConvertType segfaults on GPU
# # TODO(mattjj): remove fallback when minimum jaxlib is 0.1.72 or newer
# if jaxlib_version >= (0, 1, 72) and not backend_is_gpu:
# u64_etype = xla.dtype_to_primitive_type(np.dtype('uint64'))
# key = xops.BitcastConvertType(xops.Reshape(key, (2, 2)), u64_etype)
# else:
# key = _convert_4xU32_to_2xU64_without_bitcast(c, key)
# u64_etype = xla.dtype_to_primitive_type(np.dtype('uint64'))
# key = xops.BitcastConvertType(xops.Reshape(key, (2, 2)), u64_etype)
# out_key, out_vals = xla.xla_destructure(
# c, xops.RngBitGenerator(algorithm, key, xla_shape))
# if key_dtype == np.dtype('uint32'):
# if jaxlib_version >= (0, 1, 72) and not backend_is_gpu:
# u32_etype = xla.dtype_to_primitive_type(np.dtype('uint32'))
# out_key = xops.Reshape(xops.BitcastConvertType(out_key, u32_etype), (4,))
# else:
# out_key = _convert_2xU64_to_4xU32_without_bitcast(c, out_key)
# u32_etype = xla.dtype_to_primitive_type(np.dtype('uint32'))
# out_key = xops.Reshape(xops.BitcastConvertType(out_key, u32_etype), (4,))
# return [out_key, out_vals]

View File

@ -34,14 +34,13 @@ from functools import partial
import itertools as it
import operator as op
import threading
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional,
from typing import (Any, Callable, Dict, List, Optional,
Sequence, Set, Tuple, Type, Union, Iterable)
import sys
from absl import logging
import numpy as np
import jax
from .._src.config import config
from .. import core
from .. import linear_util as lu
@ -55,7 +54,6 @@ from ..errors import JAXTypeError
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import pmap_lib
from jax._src.lib import _xla_extension_version
from ..tree_util import tree_flatten, tree_map
from . import batching
from . import partial_eval as pe
@ -97,9 +95,7 @@ MeshDimAssignment = Union[ShardedAxis, Replicated]
# mypy will consider this constant to be True at type check time.
MYPY = False
# TODO(jblespiau): Remove the version check when jaxlib 0.1.70 is the minimal
# version.
if MYPY or (not TYPE_CHECKING and _xla_extension_version < 30):
if MYPY:
class ShardingSpec:
"""Describes the sharding of an ndarray.
@ -503,8 +499,8 @@ def gsda_array_result_handler(global_aval, global_mesh, out_axis_resources):
### lazy device-memory persistence and result handling
# TODO(jblespiau): Remove when jaxlib 0.1.72 is the minimal version.
_USE_CPP_SDA = _xla_extension_version >= 38
# TODO(jblespiau): Consider removing this option.
_USE_CPP_SDA = True
def make_sharded_device_array(
@ -539,9 +535,8 @@ def make_sharded_device_array(
if (_USE_CPP_SDA and
(not device_buffers or
isinstance(device_buffers[0], xb.xla_client.Buffer))):
# TODO(slebedev): Remove the ignore once jaxlib>=0.1.71.
return pmap_lib.ShardedDeviceArray.make(
aval, sharding_spec, device_buffers, # type: ignore[arg-type, call-arg]
aval, sharding_spec, device_buffers,
indices, aval.weak_type)
return _ShardedDeviceArray(aval, sharding_spec, device_buffers, indices)
@ -1783,12 +1778,8 @@ class MeshExecutable:
use_spmd_partitioning=spmd_lowering,
)
compile_options.parameter_is_tupled_arguments = tuple_args
if jax._src.lib.version >= (0, 1, 72):
compile_options.executable_build_options.allow_spmd_sharding_propagation_to_output = \
_allow_propagation_to_outputs
elif _allow_propagation_to_outputs:
raise RuntimeError("Propagation of SPMD sharding specs to outputs is only supported "
"in jaxlib 0.1.72+. Please update your JAX version.")
compile_options.executable_build_options.allow_spmd_sharding_propagation_to_output = \
_allow_propagation_to_outputs
local_sharding_spec = mesh_sharding_specs(local_axis_sizes, mesh.axis_names)
local_input_specs = [local_sharding_spec(aval, aval_in_axes)

View File

@ -13,4 +13,4 @@
# limitations under the License.
__version__ = "0.2.26"
_minimum_jaxlib_version = "0.1.69"
_minimum_jaxlib_version = "0.1.74"

View File

@ -82,16 +82,12 @@ class CPPJitTest(jtu.BufferDonationTestCase):
# Tensorflow.
return api._cpp_jit
@unittest.skipIf(jax._src.lib._xla_extension_version < 40,
"Test requires jaxlib 0.1.73")
def test_jit_repr(self):
def my_function():
return
jitted = jit(my_function)
self.assertEqual(repr(jitted), f"<CompiledFunction of {repr(my_function)}>")
@unittest.skipIf(jax._src.lib._xla_extension_version < 40,
"Test requires jaxlib 0.1.73")
def test_jit_repr_errors(self):
class Callable:
def __call__(self): pass
@ -692,8 +688,6 @@ class CPPJitTest(jtu.BufferDonationTestCase):
np.testing.assert_allclose(f_pruned(*args), 3)
self.assertEqual(count[0], 1)
@unittest.skipIf(jax._src.lib._xla_extension_version <= 36,
"Test requires jaxlib 0.1.71")
def testBuffersAreFreedPromptly(self):
# Regression test for a bug where garbage collection was delayed too long
# for NumPy buffers that are aliased zero-copy by the runtime.

View File

@ -97,9 +97,7 @@ class DebugNaNsTest(jtu.JaxTestCase):
f(1)
def testPmap(self):
pmap_funcs = [api._python_pmap]
if jax._src.lib._xla_extension_version >= 36:
pmap_funcs.append(api._cpp_pmap)
pmap_funcs = [api._python_pmap, api._cpp_pmap]
for pmap in pmap_funcs:
f = pmap(lambda x: 0. / x)

View File

@ -19,7 +19,6 @@ from functools import partial
import itertools
import typing
from typing import Any, Optional, Tuple
import unittest
import warnings
from absl.testing import absltest
@ -1241,8 +1240,6 @@ class IndexedUpdateTest(jtu.JaxTestCase):
with self.assertNoWarnings():
x.at[normalize(idx)].set(0)
@unittest.skipIf(jax._src.lib.version < (0, 1, 72),
"Bug fixed in jaxlib 0.1.72")
def testIndexedUpdateAliasingBug(self):
# https://github.com/google/jax/issues/7461
fn = lambda x: x.at[1:].set(1 + x[:-1])

View File

@ -1453,8 +1453,6 @@ class LaxLinalgTest(jtu.JaxTestCase):
for dtype in float_types + complex_types))
@jtu.skip_on_devices("gpu", "tpu")
def testSchur(self, shape, dtype):
if jax._src.lib.version < (0, 1, 72):
self.skipTest("Schur LAPACK wrapper only implemented for jaxlib versions >= 0.1.72")
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
@ -1470,8 +1468,6 @@ class LaxLinalgTest(jtu.JaxTestCase):
for dtype in float_types + complex_types))
@jtu.skip_on_devices("gpu", "tpu")
def testSchurBatching(self, shape, dtype):
if jax._src.lib.version < (0, 1, 72):
self.skipTest("Schur LAPACK wrapper only implemented for jaxlib versions >= 0.1.72")
rng = jtu.rand_default(self.rng())
batch_size = 10
shape = (batch_size, ) + shape

View File

@ -35,8 +35,6 @@ config.parse_flags_with_absl()
class CloudpickleTest(jtu.JaxTestCase):
@unittest.skipIf(cloudpickle is None, "Requires cloudpickle")
@unittest.skipIf(jax._src.lib._xla_extension_version < 31,
"Requires jaxlib 0.1.71")
def testPickleOfJittedFunctions(self):
@jax.jit
@ -56,8 +54,6 @@ class CloudpickleTest(jtu.JaxTestCase):
self.assertEqual(expected, actual)
@unittest.skipIf(cloudpickle is None, "Requires cloudpickle")
@unittest.skipIf(jax._src.lib._xla_extension_version < 39,
"Requires jaxlib 0.1.72")
def testPickleOfPmappedFunctions(self):
@jax.pmap

View File

@ -266,7 +266,6 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertTrue(hasattr(y, "sharding_spec"))
@check_1d_2d_mesh(set_mesh=True)
@unittest.skipIf(jax._src.lib.version < (0, 1, 72), "Needs jaxlib 0.1.72+")
def testAutodiff(self, mesh, resources):
if len(mesh) != 2: return
assert resources == ('x', 'y')

View File

@ -1821,10 +1821,7 @@ class CppPmapTest(PythonPmapTest):
@property
def pmap(self):
if jax._src.lib._xla_extension_version >= 38:
return src_api._cpp_pmap
else:
return src_api._python_pmap
return src_api._cpp_pmap
class VmapOfPmapTest(jtu.JaxTestCase):

View File

@ -20,13 +20,12 @@ import re
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import test_util as jtu
from jax import tree_util
from jax._src.tree_util import _process_pytree
from jax import flatten_util
import jax.numpy as jnp
import jax._src.lib
def _dummy_func(*args, **kwargs):
return
@ -212,9 +211,6 @@ class TreeTest(jtu.JaxTestCase):
def testTreedefTupleFromChildren(self):
# https://github.com/google/jax/issues/7377
# TODO(frostig): remove after the minimum jaxlib is 0.1.70 or newer.
if jax._src.lib._xla_extension_version < 29:
self.skipTest("fixed in future jaxlib")
tree = ((1, 2, (3, 4)), (5,))
leaves, treedef1 = tree_util.tree_flatten(tree)
treedef2 = tree_util.treedef_tuple(treedef1.children())
@ -331,8 +327,6 @@ class TreeTest(jtu.JaxTestCase):
self.assertRegex(str(treedef), correct_string)
def testTreeDefWithEmptyDictStringRepresentation(self):
if jax._src.lib._xla_extension_version < 35:
self.skipTest("fixed in future jaxlib")
self.assertEqual(str(tree_util.tree_structure({})), "PyTreeDef({})")