mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Increase minimum Jaxlib version to 0.1.22.
Remove code that preserves backward compatibility with older jaxlib versions.
This commit is contained in:
parent
dd37f8c961
commit
2369d1fe61
@ -439,7 +439,7 @@ class ShardedDeviceArray(ShardedDeviceValue, xla.DeviceArray):
|
||||
def copy_to_host_async(self):
|
||||
if self._npy_value is None:
|
||||
for buf in self.device_buffers:
|
||||
xla._copy_to_host_async(buf)
|
||||
buf.copy_to_host_async()
|
||||
|
||||
def delete(self):
|
||||
for buf in self.device_buffers:
|
||||
|
@ -145,12 +145,7 @@ def device_put(x, device_num=0):
|
||||
if x.device_buffer.device() == device_num:
|
||||
return x.device_buffer
|
||||
else:
|
||||
# TODO(phawkins): remove after the minimum Jaxlib version is raised to
|
||||
# 0.1.22
|
||||
if hasattr(x.device_buffer, 'copy_to_device'):
|
||||
return x.device_buffer.copy_to_device(device_num)
|
||||
else:
|
||||
return device_put(x.device_buffer.to_py(), device_num)
|
||||
return x.device_buffer.copy_to_device(device_num)
|
||||
elif isinstance(x, DeviceConstant):
|
||||
return _instantiate_device_constant(x, device_num=device_num)
|
||||
elif isinstance(x, (DeviceArray, onp.ndarray)):
|
||||
@ -510,13 +505,6 @@ xb.register_constant_handler(DeviceTuple, _device_tuple_constant_handler)
|
||||
# TODO(mattjj): could jit-compile a computation here
|
||||
ad_util.jaxval_adders[DeviceTuple] = ad_util.add_jaxtuples
|
||||
|
||||
# TODO(phawkins): after Jaxlib 0.1.17 has been released, bump the minimum
|
||||
# jaxlib version and change callers of this function to simply call
|
||||
# the copy_to_host_async method directly.
|
||||
def _copy_to_host_async(buffer):
|
||||
if hasattr(buffer, "copy_to_host_async"):
|
||||
buffer.copy_to_host_async()
|
||||
|
||||
|
||||
def _forward_method(attrname, self, fun, *args):
|
||||
return fun(getattr(self, attrname), *args)
|
||||
@ -550,7 +538,7 @@ class DeviceArray(DeviceValue):
|
||||
"""Requests a copy of the buffer to the host."""
|
||||
self._check_if_deleted()
|
||||
if self._npy_value is None:
|
||||
_copy_to_host_async(self.device_buffer)
|
||||
self.device_buffer.copy_to_host_async()
|
||||
|
||||
def delete(self):
|
||||
"""Deletes the device array and any cached copy on the host.
|
||||
|
@ -387,17 +387,7 @@ def concatenate(operands, dimension):
|
||||
return concatenate_p.bind(*operands, dimension=dimension,
|
||||
operand_shapes=tuple(o.shape for o in operands))
|
||||
|
||||
# TODO(phawkins): remove once the minimum Jaxlib version is increased to 0.1.22.
|
||||
_supports_precision = hasattr(xla_client, "PrecisionConfig")
|
||||
|
||||
if _supports_precision:
|
||||
Precision = xla_client.PrecisionConfig.Precision
|
||||
else:
|
||||
# Dummy for backward compatibility with older Jaxlib versions.
|
||||
class Precision(enum.Enum):
|
||||
DEFAULT = 0
|
||||
HIGH = 1
|
||||
HIGHEST = 2
|
||||
Precision = xla_client.PrecisionConfig.Precision
|
||||
|
||||
def conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation=None,
|
||||
rhs_dilation=None, dimension_numbers=None,
|
||||
@ -1932,10 +1922,10 @@ def _conv_general_dilated_translation_rule(
|
||||
dimension_numbers, feature_group_count, precision, **unused_kwargs):
|
||||
assert type(dimension_numbers) is ConvDimensionNumbers
|
||||
dimension_numbers = _conv_general_proto(dimension_numbers)
|
||||
kwargs = _precision_config_kwargs(precision)
|
||||
return c.ConvGeneralDilated(lhs, rhs, window_strides, padding, lhs_dilation,
|
||||
rhs_dilation, dimension_numbers,
|
||||
feature_group_count, **kwargs)
|
||||
feature_group_count,
|
||||
precision_config=_precision_config(precision))
|
||||
|
||||
def _conv_general_dilated_batch_rule(
|
||||
batched_args, batch_dims, window_strides, padding,
|
||||
@ -2120,18 +2110,15 @@ def _dot_batch_rule(batched_args, batch_dims, precision=None):
|
||||
dim_nums = [(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)]
|
||||
return dot_general(lhs, rhs, dim_nums, precision=precision), 0
|
||||
|
||||
# TODO(phawkins): pass precision_config directly after the minimum Jaxlib
|
||||
# version has been increased to 0.1.22.
|
||||
def _precision_config_kwargs(precision):
|
||||
kwargs = {}
|
||||
def _precision_config(precision):
|
||||
if precision is not None:
|
||||
config = xla_client.PrecisionConfig()
|
||||
config.operand_precision.extend((precision, precision))
|
||||
kwargs["precision_config"] = config
|
||||
return kwargs
|
||||
return config
|
||||
return None
|
||||
|
||||
def _dot_translation_rule(c, lhs, rhs, precision):
|
||||
return c.Dot(lhs, rhs, **_precision_config_kwargs(precision))
|
||||
return c.Dot(lhs, rhs, precision_config=_precision_config(precision))
|
||||
|
||||
_dot_dtype_rule = partial(binop_dtype_rule, _input_dtype, [_num, _num], 'dot')
|
||||
dot_p = standard_primitive(_dot_shape_rule, _dot_dtype_rule, 'dot',
|
||||
@ -2239,7 +2226,7 @@ def _dot_general_batch_rule(batched_args, batch_dims, dimension_numbers,
|
||||
|
||||
def _dot_general_translation_rule(c, lhs, rhs, dimension_numbers, precision):
|
||||
return c.DotGeneral(lhs, rhs, dimension_numbers,
|
||||
**_precision_config_kwargs(precision))
|
||||
precision_config=_precision_config(precision))
|
||||
|
||||
dot_general_p = standard_primitive(_dot_general_shape_rule,
|
||||
_dot_general_dtype_rule, 'dot_general',
|
||||
@ -4222,10 +4209,6 @@ def remaining(original, *removed_lists):
|
||||
def _canonicalize_precision(precision):
|
||||
if precision is None:
|
||||
return None
|
||||
if not _supports_precision:
|
||||
warnings.warn("Precision specifications require Jaxlib >= 0.1.22; ignoring "
|
||||
"precision specification")
|
||||
return None
|
||||
if isinstance(precision, Precision):
|
||||
return precision
|
||||
else:
|
||||
|
@ -34,7 +34,7 @@ import six
|
||||
|
||||
import jaxlib
|
||||
|
||||
_minimum_jaxlib_version = (0, 1, 14)
|
||||
_minimum_jaxlib_version = (0, 1, 22)
|
||||
try:
|
||||
from jaxlib import version as jaxlib_version
|
||||
except:
|
||||
@ -57,14 +57,6 @@ _check_jaxlib_version()
|
||||
from jaxlib import xla_client
|
||||
from jaxlib import xla_data_pb2
|
||||
|
||||
# TODO(phawkins): This is a workaround for older jaxlib versions. Remove after a
|
||||
# jaxlib release.
|
||||
try:
|
||||
from jaxlib import xrt
|
||||
except ImportError:
|
||||
xrt = None
|
||||
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_bool('jax_enable_x64',
|
||||
strtobool(os.getenv('JAX_ENABLE_X64', 'False')),
|
||||
@ -130,8 +122,6 @@ def _get_xrt_backend():
|
||||
worker = "tpu_worker"
|
||||
tf_context = xrt.get_tf_context(FLAGS.jax_backend_target, worker)
|
||||
backend = xrt.XrtBackend(tf_context, tf_device_name)
|
||||
# TODO(phawkins) fix XrtBackend to set the following and remove this line.
|
||||
backend.platform = "TPU"
|
||||
return backend
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user