Increase minimum Jaxlib version to 0.1.22.

Remove code that preserves backward compatibility with older jaxlib versions.
This commit is contained in:
Peter Hawkins 2019-07-23 21:45:41 -04:00
parent dd37f8c961
commit 2369d1fe61
4 changed files with 12 additions and 51 deletions

View File

@ -439,7 +439,7 @@ class ShardedDeviceArray(ShardedDeviceValue, xla.DeviceArray):
def copy_to_host_async(self):
if self._npy_value is None:
for buf in self.device_buffers:
xla._copy_to_host_async(buf)
buf.copy_to_host_async()
def delete(self):
for buf in self.device_buffers:

View File

@ -145,12 +145,7 @@ def device_put(x, device_num=0):
if x.device_buffer.device() == device_num:
return x.device_buffer
else:
# TODO(phawkins): remove after the minimum Jaxlib version is raised to
# 0.1.22
if hasattr(x.device_buffer, 'copy_to_device'):
return x.device_buffer.copy_to_device(device_num)
else:
return device_put(x.device_buffer.to_py(), device_num)
return x.device_buffer.copy_to_device(device_num)
elif isinstance(x, DeviceConstant):
return _instantiate_device_constant(x, device_num=device_num)
elif isinstance(x, (DeviceArray, onp.ndarray)):
@ -510,13 +505,6 @@ xb.register_constant_handler(DeviceTuple, _device_tuple_constant_handler)
# TODO(mattjj): could jit-compile a computation here
ad_util.jaxval_adders[DeviceTuple] = ad_util.add_jaxtuples
# TODO(phawkins): after Jaxlib 0.1.17 has been released, bump the minimum
# jaxlib version and change callers of this function to simply call
# the copy_to_host_async method directly.
def _copy_to_host_async(buffer):
if hasattr(buffer, "copy_to_host_async"):
buffer.copy_to_host_async()
def _forward_method(attrname, self, fun, *args):
return fun(getattr(self, attrname), *args)
@ -550,7 +538,7 @@ class DeviceArray(DeviceValue):
"""Requests a copy of the buffer to the host."""
self._check_if_deleted()
if self._npy_value is None:
_copy_to_host_async(self.device_buffer)
self.device_buffer.copy_to_host_async()
def delete(self):
"""Deletes the device array and any cached copy on the host.

View File

@ -387,17 +387,7 @@ def concatenate(operands, dimension):
return concatenate_p.bind(*operands, dimension=dimension,
operand_shapes=tuple(o.shape for o in operands))
# TODO(phawkins): remove once the minimum Jaxlib version is increased to 0.1.22.
_supports_precision = hasattr(xla_client, "PrecisionConfig")
if _supports_precision:
Precision = xla_client.PrecisionConfig.Precision
else:
# Dummy for backward compatibility with older Jaxlib versions.
class Precision(enum.Enum):
DEFAULT = 0
HIGH = 1
HIGHEST = 2
Precision = xla_client.PrecisionConfig.Precision
def conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation=None,
rhs_dilation=None, dimension_numbers=None,
@ -1932,10 +1922,10 @@ def _conv_general_dilated_translation_rule(
dimension_numbers, feature_group_count, precision, **unused_kwargs):
assert type(dimension_numbers) is ConvDimensionNumbers
dimension_numbers = _conv_general_proto(dimension_numbers)
kwargs = _precision_config_kwargs(precision)
return c.ConvGeneralDilated(lhs, rhs, window_strides, padding, lhs_dilation,
rhs_dilation, dimension_numbers,
feature_group_count, **kwargs)
feature_group_count,
precision_config=_precision_config(precision))
def _conv_general_dilated_batch_rule(
batched_args, batch_dims, window_strides, padding,
@ -2120,18 +2110,15 @@ def _dot_batch_rule(batched_args, batch_dims, precision=None):
dim_nums = [(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)]
return dot_general(lhs, rhs, dim_nums, precision=precision), 0
# TODO(phawkins): pass precision_config directly after the minimum Jaxlib
# version has been increased to 0.1.22.
def _precision_config_kwargs(precision):
kwargs = {}
def _precision_config(precision):
if precision is not None:
config = xla_client.PrecisionConfig()
config.operand_precision.extend((precision, precision))
kwargs["precision_config"] = config
return kwargs
return config
return None
def _dot_translation_rule(c, lhs, rhs, precision):
return c.Dot(lhs, rhs, **_precision_config_kwargs(precision))
return c.Dot(lhs, rhs, precision_config=_precision_config(precision))
_dot_dtype_rule = partial(binop_dtype_rule, _input_dtype, [_num, _num], 'dot')
dot_p = standard_primitive(_dot_shape_rule, _dot_dtype_rule, 'dot',
@ -2239,7 +2226,7 @@ def _dot_general_batch_rule(batched_args, batch_dims, dimension_numbers,
def _dot_general_translation_rule(c, lhs, rhs, dimension_numbers, precision):
return c.DotGeneral(lhs, rhs, dimension_numbers,
**_precision_config_kwargs(precision))
precision_config=_precision_config(precision))
dot_general_p = standard_primitive(_dot_general_shape_rule,
_dot_general_dtype_rule, 'dot_general',
@ -4222,10 +4209,6 @@ def remaining(original, *removed_lists):
def _canonicalize_precision(precision):
if precision is None:
return None
if not _supports_precision:
warnings.warn("Precision specifications require Jaxlib >= 0.1.22; ignoring "
"precision specification")
return None
if isinstance(precision, Precision):
return precision
else:

View File

@ -34,7 +34,7 @@ import six
import jaxlib
_minimum_jaxlib_version = (0, 1, 14)
_minimum_jaxlib_version = (0, 1, 22)
try:
from jaxlib import version as jaxlib_version
except:
@ -57,14 +57,6 @@ _check_jaxlib_version()
from jaxlib import xla_client
from jaxlib import xla_data_pb2
# TODO(phawkins): This is a workaround for older jaxlib versions. Remove after a
# jaxlib release.
try:
from jaxlib import xrt
except ImportError:
xrt = None
FLAGS = flags.FLAGS
flags.DEFINE_bool('jax_enable_x64',
strtobool(os.getenv('JAX_ENABLE_X64', 'False')),
@ -130,8 +122,6 @@ def _get_xrt_backend():
worker = "tpu_worker"
tf_context = xrt.get_tf_context(FLAGS.jax_backend_target, worker)
backend = xrt.XrtBackend(tf_context, tf_device_name)
# TODO(phawkins) fix XrtBackend to set the following and remove this line.
backend.platform = "TPU"
return backend