Copybara import of the project:

--
1ecf4f02891cad70cc8f094b49cf2458105ca366 by George Necula <gcnecula@gmail.com>:

[jax2tf] Change the conversion of dot_general to use XLA op.

Instead of converting dot_general to a sea of TF ops, when
enable_xla is set we just use the XLA op. This has the advantage
that it also supports preferred_element_type.

Fixed a bug with passing the precision parameter to TF.
Also improved the tests to print the HLO in case of numerical errors.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/6717 from gnecula:tf_dot 1ecf4f02891cad70cc8f094b49cf2458105ca366
PiperOrigin-RevId: 373326655
George Necula 2021-05-12 02:29:51 -07:00 committed by jax authors
parent aa74314c1a
commit 235eb8c2b4
15 changed files with 212 additions and 90 deletions
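To make the effect of the change concrete, here is a minimal sketch (not part of the commit) of converting a `lax.dot_general` that uses `preferred_element_type`; with `enable_xla=True` the converter now emits a single `XlaDot`/`XlaDotV2` op instead of an einsum, so the requested accumulation type is preserved. The int8-to-int32 accumulation and the shapes are illustrative assumptions.

```python
# Hypothetical usage sketch; not taken from the commit itself.
import numpy as np
import jax.numpy as jnp
from jax import lax
from jax.experimental import jax2tf
import tensorflow as tf

def dot_i32(x, y):
  # Ask XLA DotGeneral to accumulate the int8 matmul in int32.
  return lax.dot_general(x, y, dimension_numbers=(((1,), (0,)), ((), ())),
                         preferred_element_type=jnp.int32)

dot_tf = jax2tf.convert(dot_i32, enable_xla=True)
x = np.ones((3, 4), dtype=np.int8)
y = np.ones((4, 5), dtype=np.int8)
print(dot_tf(tf.constant(x), tf.constant(y)).dtype)  # expected: int32
```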

View File

@@ -11,9 +11,15 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
 ## jax 0.2.14 (unreleased)
 * [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.13...master).
+* Bug fixes:
+  * The {func}`jax2tf.convert` now converts `lax.dot_general` using the
+    `XlaDot` TensorFlow op, for better fidelity w.r.t. JAX numerical precision
+    ({jax-issue}`#6717`).
 ## jaxlib 0.1.67 (unreleased)
 ## jaxlib 0.1.66 (May 11 2021)
 * New features:
   * CUDA 11.1 wheels are now supported on all CUDA 11 versions 11.1 or higher.

View File

@@ -6652,28 +6652,35 @@ def remaining(original, *removed_lists):
   return [i for i in original if i not in removed]
-def _canonicalize_precision(precision):
+def _canonicalize_precision(precision: PrecisionLike) -> Optional[Tuple[PrecisionType, PrecisionType]]:
+  """Turns an API precision specification, into a pair of enumeration values.
+
+  The API can take the precision as a string, or int, and either as a single
+  value to apply to both operands, or as a sequence of two values.
+  """
   if precision is None:
     if config.jax_default_matmul_precision is None:
       return None
     try:
-      return _precision_strings[config.jax_default_matmul_precision]
+      precision = _precision_strings[config.jax_default_matmul_precision]
+      return (precision, precision)
     except KeyError:
       raise ValueError(
           "jax_default_matmul_precision flag must be set to None or a value in "
           f"{_precision_strings}, but got {config.jax_default_matmul_precision}"
       ) from None
   elif isinstance(precision, str) and precision in _precision_strings:
-    return _precision_strings.get(precision)
+    precision = _precision_strings.get(precision)
+    return (precision, precision)
   elif isinstance(precision, Precision):
-    return precision
+    return (precision, precision)
   elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and
         all(isinstance(p, Precision) for p in precision)):
-    return precision
+    return precision  # type: ignore[return-value]
   elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and
         all(isinstance(s, str) for s in precision)):
     s1, s2 = precision
-    return (_canonicalize_precision(s1), _canonicalize_precision(s2))
+    return (_canonicalize_precision(s1)[0], _canonicalize_precision(s2)[0])  # type: ignore
   else:
     raise ValueError(
         f"Precision argument must be None, a string in {_precision_strings}, "

View File

@@ -456,6 +456,7 @@ We use the following TFXLA ops:
 * `XlaPad` (wraps XLA Pad operator). We use this instead of `tf.pad` in order to
   support `lax.pad` interior padding (dilation) or negative edge padding.
 * `XlaConv` (wraps XLA ConvGeneralDilated operator).
+* `XlaDot` and `XlaDotV2` (wraps XLA DotGeneral operator).
 * `XlaGather` (wraps XLA Gather operator). We could use `tf.gather` in some
   cases but not always. Also, `tf.gather` has a different semantics than `lax.gather`
   for index out of bounds.
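A quick way to confirm which TFXLA op the converter emitted is to list the op types in the converted function's graph. This is a sketch; whether you see `XlaDot` or `XlaDotV2` depends on the TensorFlow version.

```python
import jax.numpy as jnp
from jax.experimental import jax2tf
import tensorflow as tf

f_tf = tf.function(jax2tf.convert(jnp.matmul, enable_xla=True), autograph=False)
cf = f_tf.get_concrete_function(tf.TensorSpec((3, 4), tf.float32),
                                tf.TensorSpec((4, 5), tf.float32))
print(sorted({op.type for op in cf.graph.get_operations()}))
# The list should contain "XlaDotV2" (or "XlaDot" on older TF builds).
```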

View File

@@ -15,7 +15,7 @@
 See README.md for instructions.
 """
-import grpc
+import grpc  # type: ignore[import]
 import json
 import logging
 import requests
@@ -26,9 +26,9 @@ from absl import flags
 from jax.experimental.jax2tf.examples import mnist_lib  # type: ignore
 import numpy as np
-import tensorflow as tf  # type: ignore
-import tensorflow_datasets as tfds  # type: ignore
-from tensorflow_serving.apis import predict_pb2
+import tensorflow as tf  # type: ignore[import]
+import tensorflow_datasets as tfds  # type: ignore[import]
+from tensorflow_serving.apis import predict_pb2  # type: ignore[import]
 from tensorflow_serving.apis import prediction_service_pb2_grpc

View File

@@ -21,8 +21,8 @@ from jax.experimental.jax2tf.examples import mnist_lib
 import numpy as np
-import tensorflow as tf
-import tensorflow_datasets as tfds
+import tensorflow as tf  # type: ignore[import]
+import tensorflow_datasets as tfds  # type: ignore[import]
 flags.DEFINE_string('tflite_file_path',
                     '/usr/local/google/home/qiuminxu/jax2tf/mnist.tflite',

View File

@@ -1,11 +1,11 @@
 # Primitives with limited JAX support
-*Last generated on: 2021-03-29* (YYYY-MM-DD)
+*Last generated on: 2021-05-12* (YYYY-MM-DD)
 ## Supported data types for primitives
-We use a set of 2313 test harnesses to test
-the implementation of 122 numeric JAX primitives.
+We use a set of 2418 test harnesses to test
+the implementation of 121 numeric JAX primitives.
 We consider a JAX primitive supported for a particular data
 type if it is supported on at least one device type.
 The following table shows the dtypes at which primitives
@@ -76,7 +76,7 @@ be updated.
 | device_put | 16 | all | |
 | digamma | 4 | floating | bool, complex, integer |
 | div | 20 | inexact, integer | bool |
-| dot_general | 125 | all | |
+| dot_general | 245 | all | |
 | dynamic_slice | 32 | all | |
 | dynamic_update_slice | 21 | all | |
 | eig | 72 | inexact | bool, integer |
@@ -156,7 +156,6 @@ be updated.
 | svd | 120 | inexact | bool, integer |
 | tan | 6 | inexact | bool, integer |
 | tanh | 6 | inexact | bool, integer |
-| tie_in | 15 | all | |
 | top_k | 15 | bool, floating, integer | complex |
 | transpose | 17 | all | |
 | triangular_solve | 26 | inexact | bool, integer |
@@ -188,6 +187,9 @@ and search for "limitation".
 |cummax|unimplemented|complex64|tpu|
 |cummin|unimplemented|complex64|tpu|
 |cumprod|unimplemented|complex64|tpu|
+|dot_general|preferred_element_type=c128 not implemented|complex64|tpu|
+|dot_general|preferred_element_type=f64 crashes (b/187884887)|bfloat16, float16, float32|tpu|
+|dot_general|preferred_element_type=i64 not implemented|int16, int32, int8|tpu|
 |eig|only supported on CPU in JAX|all|tpu, gpu|
 |eig|unimplemented|bfloat16, float16|cpu|
 |eigh|complex eigh not supported |complex|tpu|
@@ -202,7 +204,6 @@ and search for "limitation".
 |select_and_scatter_add|works only for 2 or more inactive dimensions|all|tpu|
 |svd|complex not implemented. Works in JAX for CPU and GPU with custom kernels|complex|tpu|
 |svd|unimplemented|bfloat16, float16|cpu, gpu|
-|tie_in|requires omnistaging to be disabled|all|cpu, gpu, tpu|
 |triangular_solve|unimplemented|float16|gpu|
 ## Table generation

View File

@@ -68,6 +68,8 @@ def _sanitize_scope_name(name):
 # or Python scalar or numpy.ndarray. (A tf.EagerTensor is a tf.Tensor.)
 TfVal = Any
 DType = Any
+PrecisionType = int  # Enum xla_data.PrecisionConfig.Precision
 def _is_tfval(v: TfVal) -> bool:
   if isinstance(v, (tf.Tensor, tf.Variable)):
     return True
@@ -123,7 +125,6 @@ def _xla_path_disabled_error(primitive_name: str) -> Exception:
 @functools.partial(api_util.api_hook, tag="jax2tf_convert")
 def convert(fun: Callable, *,
             polymorphic_shapes: Optional[Sequence[Any]]=None,
-            in_shapes=None,  # DEPRECATED
             with_gradient=True, enable_xla=True) -> Callable:
   """Transforms `fun` to be executed by TensorFlow.
@@ -222,13 +223,13 @@ def convert(fun: Callable, *,
       raise TypeError(msg)
     polymorphic_shapes_ = tuple(polymorphic_shapes)
-  # Expand the in_shapes to match the argument pytree
+  # Expand the polymorphic_shapes to match the argument pytree
   polymorphic_shapes_flat = tuple(api_util.flatten_axes("jax2tf.convert polymorphic_shapes",
                                                         in_tree.children()[0],
                                                         polymorphic_shapes_))
   # Construct the abstract values for the flat arguments, possibly based on
-  # the input shapes and the in_shapes if given. May create new shape
+  # the input shapes and the polymorphic_shapes if given. May create new shape
   # variables.
   args_avals_flat, shapeenv = _args_to_avals_and_env(args_flat,
                                                      polymorphic_shapes_flat)
@@ -557,11 +558,16 @@ class TensorFlowTracer(core.Tracer):
     elif isinstance(val, (tf.Tensor, tf.Variable)):
       val_shape, val_dtype = _tfval_shape_dtype(val)
       aval_dtype = np.dtype(self._aval.dtype)  # type: ignore[attr-defined]
-      if val_dtype != aval_dtype and (val_dtype == tf.int32 and aval_dtype == jnp.int64 or
-                                      val_dtype == tf.int64 and aval_dtype == jnp.int32 or
-                                      val_dtype == tf.float32 and aval_dtype == jnp.float64 or
-                                      val_dtype == tf.float64 and aval_dtype == jnp.float32):
-        # We expect that x64 values are turned into x32
+      if (val_dtype != aval_dtype and
+          not config.x64_enabled and
+          (val_dtype == tf.int32 and aval_dtype == jnp.int64 or
+           val_dtype == tf.int64 and aval_dtype == jnp.int32 or
+           val_dtype == tf.float32 and aval_dtype == jnp.float64 or
+           val_dtype == tf.float64 and aval_dtype == jnp.float32 or
+           val_dtype == tf.complex128 and aval_dtype == jnp.complex64)):
+        # If JAX does not have x64 bit mode enabled, it will force the 64-bit
+        # values to use 32-bit precision. In order to make the TF conversion
+        # follow JAX's rules, we cast the TF values down to 32-bit mode.
         val = tf.cast(val, dtype=aval_dtype)
         val_dtype = aval_dtype
@@ -569,7 +575,8 @@ class TensorFlowTracer(core.Tracer):
       assert aval_dtype == val_dtype, f"expected {aval_dtype} == {val_dtype}"
     for aval_dim, val_dim in util.safe_zip(self._aval.shape, val_shape):  # type: ignore[attr-defined]
       if val_dim is None:
-        assert isinstance(aval_dim, shape_poly.DimVar), f"expected {self._aval.shape} == {val_shape}"  # type: ignore[attr-defined]
+        assert isinstance(aval_dim,
+                          shape_poly.DimVar), f"expected {self._aval.shape} == {val_shape}"  # type: ignore[attr-defined]
       elif not isinstance(aval_dim, shape_poly.DimVar):
         assert aval_dim == val_dim, f"expected {self._aval.shape} == {val_shape}"  # type: ignore[attr-defined]
       else:
@@ -1082,13 +1089,14 @@ def _conv_general_dimension_numbers_proto(dimension_numbers):
   return proto
-def _conv_general_precision_config_proto(precision):
+def _precision_config_proto(precision: Optional[Tuple[PrecisionType, PrecisionType]]):
   """Convert an integer to an XLA.PrecisionConfig."""
   if precision is None:
     return None
   proto = xla_data_pb2.PrecisionConfig()
-  proto.operand_precision.append(int(precision))
+  proto.operand_precision.append(int(precision[0]))
+  proto.operand_precision.append(int(precision[1]))
   return proto
 # _try_tf_conv returns a Tensor when it succeeds, or a string describing why
@@ -1196,9 +1204,11 @@ def _try_tf_conv(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
       return error
   return convert_dilation_and_compute_result(tf_padding, tf_dim_nums)
-def _conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation,
+def _conv_general_dilated(lhs, rhs, *,
+                          window_strides, padding, lhs_dilation,
                           rhs_dilation, dimension_numbers, feature_group_count,
-                          batch_group_count, lhs_shape, rhs_shape, precision,
+                          batch_group_count, lhs_shape, rhs_shape,
+                          precision: Optional[Tuple[PrecisionType, PrecisionType]],
                           preferred_element_type, _in_avals, _out_aval):
   """Implementation of lax.conv_general_dilated_p using XlaConv."""
   if not _enable_xla:
@@ -1212,7 +1222,7 @@ def _conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation,
     raise _xla_path_disabled_error("conv_general_dilated")
   dnums_proto = _conv_general_dimension_numbers_proto(dimension_numbers)
-  precision_config_proto = _conv_general_precision_config_proto(precision)
+  precision_config_proto = _precision_config_proto(precision)
   assert batch_group_count == 1  # TODO(phawkins): implement batch_group_count
   out = tfxla.conv(
       lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
@@ -1226,24 +1236,38 @@ def _conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation,
 tf_impl_with_avals[lax.conv_general_dilated_p] = _conv_general_dilated
-def _dot_general(lhs, rhs, dimension_numbers, precision, preferred_element_type):
+def _dot_general(lhs, rhs, *,
+                 dimension_numbers,
+                 precision: Optional[Tuple[PrecisionType, PrecisionType]],
+                 preferred_element_type: Optional[DType],
+                 _in_avals: Sequence[core.AbstractValue],
+                 _out_aval: core.AbstractValue):
   """Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
-  del precision
-  del preferred_element_type
   (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
   lhs_ndim, rhs_ndim = len(lhs.shape), len(rhs.shape)
+  if _enable_xla:
+    dnums_proto = xla_data_pb2.DotDimensionNumbers()
+    dnums_proto.lhs_contracting_dimensions.extend(lhs_contracting)
+    dnums_proto.rhs_contracting_dimensions.extend(rhs_contracting)
+    dnums_proto.lhs_batch_dimensions.extend(lhs_batch)
+    dnums_proto.rhs_batch_dimensions.extend(rhs_batch)
+    precision_config_proto = _precision_config_proto(precision)
+    res = tfxla.dot_general(lhs, rhs, dnums_proto, precision_config_proto,
+                            preferred_element_type=preferred_element_type)
+    # TODO: in presence of None dimensions, XlaDot shape inference returns
+    # unknown shape.
+    res.set_shape(_aval_to_tf_shape(_out_aval))
+    return res
   # This condition ensures that:
-  # 1) the considered dtype is not tf.bfloat16/tf.int32, which are supported by
-  #    tf.linalg.einsum but not by tf.linalg.matmul;
-  # 2) the batch dimensions are ordered in the same way in lhs and rhs (this is
+  # 1) the batch dimensions are ordered in the same way in lhs and rhs (this is
   #    not strictly necessary, but we would have to reshape the array if that
   #    were not the case;
-  # 3) lhs and rhs have the same number of dimensions +/- 1
-  # 4) the number of non-batch dimensions in both tensors is either 1 or 2
-  # 5) the contracting dimensions are consistent with those of a classic
+  # 2) lhs and rhs have the same number of dimensions +/- 1
+  # 3) the number of non-batch dimensions in both tensors is either 1 or 2
+  # 4) the contracting dimensions are consistent with those of a classic
   #    matrix/matrix, vector/matrix or matrix/vector multiplication.
-  if (not lhs.dtype in [tf.bfloat16, tf.int32]
-      and lhs_batch == rhs_batch == tuple(range(len(lhs_batch)))
+  if (lhs_batch == rhs_batch == tuple(range(len(lhs_batch)))
      and lhs_ndim - rhs_ndim in [-1, 0, 1]
      and 1 <= lhs_ndim - len(lhs_batch) <= 2
     and 1 <= rhs_ndim - len(rhs_batch) <= 2
@@ -1290,16 +1314,16 @@ def _dot_general(lhs, rhs, dimension_numbers, precision, preferred_element_type)
     shared_id = next(new_id)
     lhs_axis_ids[lhs_axis] = shared_id
     rhs_axis_ids[rhs_axis] = shared_id
-    lhs_out_axis_ids[lhs_axis] = None
-    rhs_out_axis_ids[rhs_axis] = None
+    lhs_out_axis_ids[lhs_axis] = None  # type: ignore[call-overload]
+    rhs_out_axis_ids[rhs_axis] = None  # type: ignore[call-overload]
   batch_ids = []
   for lhs_axis, rhs_axis in zip(lhs_batch, rhs_batch):
     shared_id = next(new_id)
     lhs_axis_ids[lhs_axis] = shared_id
     rhs_axis_ids[rhs_axis] = shared_id
-    lhs_out_axis_ids[lhs_axis] = None
-    rhs_out_axis_ids[rhs_axis] = None
+    lhs_out_axis_ids[lhs_axis] = None  # type: ignore[call-overload]
+    rhs_out_axis_ids[rhs_axis] = None  # type: ignore[call-overload]
     batch_ids.append(shared_id)
   not_none = lambda x: x is not None
@@ -1310,7 +1334,7 @@ def _dot_general(lhs, rhs, dimension_numbers, precision, preferred_element_type)
                                 "".join(rhs_axis_ids),
                                 "".join(out_axis_ids))
     return tf.linalg.einsum(spec, lhs, rhs)
-tf_impl[lax.dot_general_p] = _dot_general
+tf_impl_with_avals[lax.dot_general_p] = _dot_general
 def _broadcast(operand, *, sizes):
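For readers less familiar with `dot_general`, a small JAX-side illustration of the `dimension_numbers` that `_dot_general` receives. The shapes are arbitrary; the pair structure is `((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))`.

```python
import numpy as np
from jax import lax

lhs = np.ones((7, 8, 4), np.float32)   # (batch, m, k)
rhs = np.ones((7, 4, 5), np.float32)   # (batch, k, n)
dnums = (((2,), (1,)), ((0,), (0,)))   # contract lhs dim 2 with rhs dim 1; batch dim 0
out = lax.dot_general(lhs, rhs, dimension_numbers=dnums)
print(out.shape)  # (7, 8, 5): batch dims, then lhs free dims, then rhs free dims
```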

View File

@@ -460,17 +460,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
             modes="compiled",
             # Works for 2D matrices.
             enabled=(len(harness.params["lhs_shape"]) > 2)),
-        custom_numeric(dtypes=dtypes.bfloat16, tol=0.3),
-        custom_numeric(
-            dtypes=[np.complex64, np.float32], devices=("cpu", "gpu"),
-            tol=1e-5),
-        custom_numeric(
-            dtypes=[np.complex128, np.float64], devices=("cpu", "gpu"),
-            tol=1e-12),
-        custom_numeric(dtypes=np.float32, devices="tpu", tol=0.1),
-        custom_numeric(dtypes=np.complex64, devices="tpu", tol=0.3),
-        custom_numeric(dtypes=np.float16, devices=("gpu", "tpu"), tol=0.1),
-        custom_numeric(dtypes=np.float16, devices="cpu", tol=0.01)
     ]
   @classmethod

View File

@@ -2352,22 +2352,49 @@ def _make_dot_general_harness(name,
                               rhs_shape=(4, 2),
                               dtype=np.float32,
                               precision=None,
-                              dimension_numbers=(((1,), (0,)), ((), ()))):
+                              dimension_numbers=(((1,), (0,)), ((), ())),
+                              preferred_element_type=None):
+  suffix = ""
+  if precision is not None:
+    suffix += f"_precision={precision}"
+  if preferred_element_type is not None:
+    suffix += f"_preferred={jtu.dtype_str(preferred_element_type)}"
   define(
       lax.dot_general_p,
-      f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}_dimensionnumbers={dimension_numbers}_precision={precision}"
+      f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}_dimensionnumbers={dimension_numbers}{suffix}"
       .replace(" ", ""),
       lax.dot_general, [
           RandArg(lhs_shape, dtype),
           RandArg(rhs_shape, dtype),
           StaticArg(dimension_numbers),
-          StaticArg(precision)
+          StaticArg(precision),
+          StaticArg(preferred_element_type)
       ],
       dtype=dtype,
       lhs_shape=lhs_shape,
       rhs_shape=rhs_shape,
       dimension_numbers=dimension_numbers,
-      precision=precision)
+      precision=precision,
+      preferred_element_type=preferred_element_type,
+      jax_unimplemented=[
+          Limitation(
+              "preferred_element_type=c128 not implemented",
+              devices="tpu",
+              dtypes=np.complex64,
+              enabled=(preferred_element_type in [np.complex128])),
+          Limitation(
+              "preferred_element_type=f64 crashes (b/187884887)",
+              devices="tpu",
+              dtypes=(np.float16, jnp.bfloat16, np.float32),
+              enabled=(preferred_element_type in [np.float64]),
+              skip_run=True),
+          Limitation(
+              "preferred_element_type=i64 not implemented",
+              devices="tpu",
+              dtypes=(np.int8, np.int16, np.int32),
+              enabled=(preferred_element_type in [np.int64])),
+      ],
+  )
 # There are two execution paths in the conversion of dot_general. The main path
@@ -2422,6 +2449,28 @@ for lhs_shape, rhs_shape, dimension_numbers in [
       rhs_shape=rhs_shape,
       dimension_numbers=dimension_numbers)
+# Validate preferred element type
+# From lax_test.py
+preferred_type_combinations = [
+    (np.float16, np.float16), (np.float16, np.float32), (np.float16, np.float64),
+    (dtypes.bfloat16, np.float32),
+    (dtypes.bfloat16, np.float64), (np.float32, np.float32), (np.float32, np.float64),
+    (np.int8, np.int16), (np.int8, np.int32),
+    (np.int8, np.int64), (np.int16, np.int32), (np.int16, np.int64),
+    (np.int32, np.int32), (np.int32, np.int64),
+    (np.complex64, np.complex128)]
+for lhs_shape in [(3,), (4, 3)]:
+  for rhs_shape in [(3, ), (3, 6)]:
+    for dtype, preferred_element_type in preferred_type_combinations:
+      _make_dot_general_harness(
+          "preferred",
+          dtype=dtype,
+          lhs_shape=lhs_shape,
+          rhs_shape=rhs_shape,
+          dimension_numbers=(((len(lhs_shape) - 1,), (0,)), ((), ())),
+          preferred_element_type=preferred_element_type)
 def _make_concatenate_harness(name,
                               *,
@@ -2521,8 +2570,6 @@ for batch_group_count, feature_group_count in [
         feature_group_count=feature_group_count,
         batch_group_count=batch_group_count)
-### XXX
 # Validate variations of window_strides
 for window_strides in [(2, 3)]:
   _make_conv_harness("window_strides", window_strides=window_strides)

View File

@@ -99,7 +99,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
   # If you want to run this test for only one harness, add parameter
   # `one_containing="foo"` to parameterized below.
   @primitive_harness.parameterized(
-      primitive_harness.all_harnesses, include_jax_unimpl=False
+      primitive_harness.all_harnesses, include_jax_unimpl=False,
   )
   @jtu.ignore_warning(
       category=UserWarning, message="Using reduced precision for gradient.*")

View File

@@ -332,11 +332,11 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
         tf_value_and_grad, autograph=False).get_concrete_function(
             tf.TensorSpec([None, None, 8, 9]))
-    # The shape of the value
-    self.assertEqual((None, None, 8, 8), tuple(tf_grad.output_shapes[0]))
-    # The shape of the gradient should match the input
-    # TODO: there seems to be a bug here, the output should be (None, None, 8, 9)
-    # self.assertEqual((None, None, 8, None), tuple(tf_grad.output_shapes[1]))
+    # The shape of the value. This should be (None, None, 8, 8) but the
+    # shape inference for XlaDot is broken, and returns too many unknown
+    # dimensions.
+    self.assertEqual((None, None, None, None), tuple(tf_grad.output_shapes[0]))
+    self.assertEqual((None, None, None, None), tuple(tf_grad.output_shapes[1]))
   def test_gradients_pytree(self):
     """Shape polymorphism with gradients and pytrees for inputs and outputs."""
@@ -742,12 +742,18 @@ _POLY_SHAPE_TEST_HARNESSES = [
                   [RandArg((3,), _f32)],
                   poly_axes=[0]),
-    _make_harness("jnp_matmul", "",
+    _make_harness("jnp_matmul", "0",
                   jnp.matmul,
                   [RandArg((7, 8, 4), _f32), RandArg((7, 4, 5), _f32)],
                   poly_axes=[0, 0],
                   tol=1e-5),
+    _make_harness("jnp_matmul", "1",
+                  jnp.matmul,
+                  [RandArg((7, 8, 4), _f32), RandArg((4, 5), _f32)],
+                  poly_axes=[0, None],
+                  tol=1e-5),
     _make_harness("jnp_where", "",
                   jnp.where,
                   [RandArg((2,), np.bool_), RandArg((), _f32), RandArg((2,), _f32)],
@@ -935,6 +941,10 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
   """Tests for primitives that take shape values as parameters."""
   # This test runs for all _POLY_SHAPE_PRIMITIVE_HARNESSES.
+  # For each primitive "xxx" the
+  # test will be called "test_prim_xxx_...".
+  # If you want to run this test for only one harness that includes "foo"
+  # in the name, add parameter `one_containing="foo"` to parameterized below.
   @primitive_harness.parameterized(_POLY_SHAPE_TEST_HARNESSES)
   def test_prim(self, harness: Harness):
     args = harness.dyn_args_maker(self.rng())
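The "jnp_matmul" harnesses above exercise batch-polymorphic matmuls. Below is a hedged sketch of the same scenario outside the test harness; the `polymorphic_shapes` string syntax used here is an assumption about the API of that era, and the reported output shapes may include extra `None`s because of the XlaDot shape-inference issue noted above.

```python
import jax.numpy as jnp
from jax.experimental import jax2tf
import tensorflow as tf

f_tf = tf.function(
    jax2tf.convert(jnp.matmul, polymorphic_shapes=["(b, 8, 4)", "(b, 4, 5)"]),
    autograph=False)
cf = f_tf.get_concrete_function(tf.TensorSpec([None, 8, 4], tf.float32),
                                tf.TensorSpec([None, 4, 5], tf.float32))
# Ideally (None, 8, 5); XlaDot shape inference may report more unknowns.
print(cf.output_shapes)
```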

View File

@@ -36,7 +36,7 @@ DType = Any
 def _make_tf_args(args):
   def _convert_to_tensor(v):
     if hasattr(v, "dtype"):
-      tf.convert_to_tensor(v)
+      return tf.convert_to_tensor(v)
     return v
   return tf.nest.map_structure(_convert_to_tensor, args)
@@ -51,7 +51,6 @@ def _make_tf_input_signature(*tf_args) -> List[tf.TensorSpec]:
   return tf.nest.map_structure(_make_one_arg_signature, list(tf_args))
 def _run_tf_function(func_tf: Callable, *tf_args, mode: str):
   if mode == "eager":
     return func_tf(*tf_args)  # EAGER
@@ -188,13 +187,47 @@ class JaxToTfTestCase(jtu.JaxTestCase):
       custom_assert_lim = [l for l in jax2tf_limits if l.custom_assert]
       assert len(custom_assert_lim) <= 1, f"Expecting at most one applicable limitation with custom_assert, found {custom_assert_lim}"
-      if custom_assert_lim:
-        logging.info(log_message(f"Running custom_assert with tol={max_tol} due to {custom_assert_lim[0]}"))
-        custom_assert_lim[0].custom_assert(self, result_jax, result_tf, args=args, tol=max_tol)
-      else:
-        logging.info(log_message(f"Running default assert with tol={max_tol}"))
-        # In compiled mode we expect the same result as JAX by default
-        self.assertAllClose(result_jax, result_tf, atol=max_tol, rtol=max_tol)
+      try:
+        if custom_assert_lim:
+          logging.info(log_message(f"Running custom_assert with tol={max_tol} due to {custom_assert_lim[0]}"))
+          custom_assert_lim[0].custom_assert(self, result_jax, result_tf, args=args, tol=max_tol)
+        else:
+          logging.info(log_message(f"Running default assert with tol={max_tol}"))
+          self.assertAllClose(result_jax, result_tf, atol=max_tol, rtol=max_tol)
+      except AssertionError as e:
+        # Log the optimized HLO for compiled mode, it should match.
+        if mode != "compiled":
+          logging.info(f"[{self._testMethodName}] Not printing HLO because the "
+                       f"mode was {mode}")
+          raise
+        logging.info(f"[{self._testMethodName}] Logging HLO "
+                     f"for comparison error {e}")
+        jax_comp = jax.xla_computation(func_jax)(*args)
+        jax_hlo = jax_comp.as_hlo_text()
+        logging.info(f"[{self._testMethodName}] "
+                     f"JAX NON-OPT HLO\n{jax_hlo}")
+        tf_func_compiled = tf.function(
+            func_tf,
+            autograph=False,
+            jit_compile=True,
+            input_signature=_make_tf_input_signature(*tf_args))
+        tf_hlo = tf_func_compiled.experimental_get_compiler_ir(*tf_args)(
+            stage="hlo")
+        logging.info(f"[{self._testMethodName}] TF NON-OPT HLO\n{tf_hlo}")
+        backend = jax.lib.xla_bridge.get_backend()
+        modules = backend.compile(jax_comp).hlo_modules()
+        jax_opt_hlo = modules[0].to_string()
+        logging.info(f"[{self._testMethodName}] "
+                     f"JAX OPT HLO\n{jax_opt_hlo}")
+        tf_opt_hlo = tf_func_compiled.experimental_get_compiler_ir(*tf_args)(
+            stage="optimized_hlo")
+        logging.info(f"[{self._testMethodName}] TF OPT HLO\n{tf_opt_hlo}")
+        raise
     # end "for mode"
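The same TF facility used above can also be called directly to inspect the HLO of a converted function, which is handy when chasing a numerical mismatch by hand. This is a sketch; `experimental_get_compiler_ir` requires a reasonably recent TF and `jit_compile=True`.

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental import jax2tf
import tensorflow as tf

f_tf = tf.function(jax2tf.convert(jnp.matmul), autograph=False, jit_compile=True)
x = tf.constant(np.ones((3, 4), np.float32))
y = tf.constant(np.ones((4, 5), np.float32))
print(f_tf.experimental_get_compiler_ir(x, y)(stage="optimized_hlo"))
```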

View File

@@ -784,7 +784,11 @@ def assert_dot_precision(expected_precision, fun, *args):
                 if eqn.primitive == lax.dot_general_p]
   for precision in precisions:
     msg = "Unexpected precision: {} != {}".format(expected_precision, precision)
-    assert precision == expected_precision, msg
+    if isinstance(precision, tuple):
+      assert precision[0] == expected_precision, msg
+      assert precision[1] == expected_precision, msg
+    else:
+      assert precision == expected_precision, msg
 _CACHED_INDICES: Dict[int, Sequence[int]] = {}
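Because `dot_general` now carries a pair-valued `precision` parameter, the check above accepts either form. A quick way to see the parameter in a jaxpr (sketch):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax

x = np.ones((2, 2), np.float32)
jaxpr = jax.make_jaxpr(
    lambda a, b: jnp.dot(a, b, precision=lax.Precision.HIGHEST))(x, x)
# The dot_general equation's params now show a pair, e.g. containing
# Precision.HIGHEST for both operands.
print(jaxpr)
```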

View File

@@ -2591,23 +2591,23 @@ class APITest(jtu.JaxTestCase):
     with jax.default_matmul_precision("bfloat16"):
       x @ x  # doesn't crash
       jaxpr = jax.make_jaxpr(op.matmul)(x, x)
-    self.assertIn('precision=DEFAULT', str(jaxpr))
+    self.assertIn('Precision.DEFAULT', str(jaxpr))
     with jax.default_matmul_precision("tensorfloat32"):
       jnp.dot(x, x)  # doesn't crash
       jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
-    self.assertIn('precision=HIGH\n', str(jaxpr))
+    self.assertIn('Precision.HIGH', str(jaxpr))
     with jax.default_matmul_precision("float32"):
       jnp.dot(x, x)  # doesn't crash
       jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
-    self.assertIn('precision=HIGHEST', str(jaxpr))
+    self.assertIn('Precision.HIGHEST', str(jaxpr))
     dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)
     with jax.default_matmul_precision("tensorfloat32"):
       dot(x, x)  # doesn't crash
       jaxpr = jax.make_jaxpr(dot)(x, x)
-    self.assertIn('precision=HIGHEST', str(jaxpr))
+    self.assertIn('Precision.HIGHEST', str(jaxpr))
   def test_dot_precision_flag(self):
     x = jnp.zeros((2, 2))
@@ -2619,7 +2619,7 @@ class APITest(jtu.JaxTestCase):
       jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
     finally:
       config.FLAGS.jax_default_matmul_precision = prev_val
-    self.assertIn('precision=HIGH', str(jaxpr))
+    self.assertIn('Precision.HIGH', str(jaxpr))
     self.assertEqual(prev_val, config._read("jax_default_matmul_precision"))
     prev_val = config._read("jax_default_matmul_precision")
@@ -2629,7 +2629,7 @@ class APITest(jtu.JaxTestCase):
       jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
     finally:
       config.update('jax_default_matmul_precision', prev_val)
-    self.assertIn('precision=HIGH', str(jaxpr))
+    self.assertIn('Precision.HIGH', str(jaxpr))
     self.assertEqual(prev_val, config._read("jax_default_matmul_precision"))
   @unittest.skipIf(jax.lib._xla_extension_version <= 17,

View File

@@ -398,7 +398,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
     result, pullback = api.vjp(dot, lhs, rhs)
     gresult = lax.zeros_like_array(result)
     s = str(api.make_jaxpr(pullback)(gresult))
-    assert "precision=HIGHEST" in s
+    assert "Precision.HIGHEST" in s
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name":
@@ -429,7 +429,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
     result, pullback = api.vjp(dot_general, lhs, rhs)
     gresult = lax.zeros_like_array(result)
     s = str(api.make_jaxpr(pullback)(gresult))
-    assert "precision=HIGHEST" in s
+    assert "Precision.HIGHEST" in s
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}".format(