Copybara import of the project:

--
1ecf4f02891cad70cc8f094b49cf2458105ca366 by George Necula <gcnecula@gmail.com>:

[jax2tf] Change the conversion of dot_general to use XLA op.

Instead of converting dot_general to a sea of TF ops, when
enable_xla is set we just use the XLA op. This has the advantage
that it also supports preferred_element_type.

Fixed bug with passing the precision parameter to TF.
Also improved tests to print the HLO in case of numerical errors.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/6717 from gnecula:tf_dot 1ecf4f02891cad70cc8f094b49cf2458105ca366
PiperOrigin-RevId: 373326655
This commit is contained in:
George Necula 2021-05-12 02:29:51 -07:00 committed by jax authors
parent aa74314c1a
commit 235eb8c2b4
15 changed files with 212 additions and 90 deletions

View File

@ -11,9 +11,15 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
## jax 0.2.14 (unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.13...master).
* Bug fixes:
* The {func}`jax2tf.convert` now converts `lax.dot_general` using the
`XlaDot` TensorFlow op, for better fidelity w.r.t. JAX numerical precision
({jax-issue}`#6717`).
## jaxlib 0.1.67 (unreleased)
## jaxlib 0.1.66 (May 11 2021)
* New features:
* CUDA 11.1 wheels are now supported on all CUDA 11 versions 11.1 or higher.

View File

@ -6652,28 +6652,35 @@ def remaining(original, *removed_lists):
return [i for i in original if i not in removed]
def _canonicalize_precision(precision):
def _canonicalize_precision(precision: PrecisionLike) -> Optional[Tuple[PrecisionType, PrecisionType]]:
"""Turns an API precision specification into a pair of enumeration values.
The API can take the precision as a string, or int, and either as a single
value to apply to both operands, or as a sequence of two values.
"""
if precision is None:
if config.jax_default_matmul_precision is None:
return None
try:
return _precision_strings[config.jax_default_matmul_precision]
precision = _precision_strings[config.jax_default_matmul_precision]
return (precision, precision)
except KeyError:
raise ValueError(
"jax_default_matmul_precision flag must be set to None or a value in "
f"{_precision_strings}, but got {config.jax_default_matmul_precision}"
) from None
elif isinstance(precision, str) and precision in _precision_strings:
return _precision_strings.get(precision)
precision = _precision_strings.get(precision)
return (precision, precision)
elif isinstance(precision, Precision):
return precision
return (precision, precision)
elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and
all(isinstance(p, Precision) for p in precision)):
return precision
return precision # type: ignore[return-value]
elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and
all(isinstance(s, str) for s in precision)):
s1, s2 = precision
return (_canonicalize_precision(s1), _canonicalize_precision(s2))
return (_canonicalize_precision(s1)[0], _canonicalize_precision(s2)[0]) # type: ignore
else:
raise ValueError(
f"Precision argument must be None, a string in {_precision_strings}, "

View File

@ -456,6 +456,7 @@ We use the following TFXLA ops:
* `XlaPad` (wraps XLA Pad operator). We use this instead of `tf.pad` in order to
support `lax.pad` interior padding (dilation) or negative edge padding.
* `XlaConv` (wraps XLA ConvGeneralDilated operator).
* `XlaDot` and `XlaDotV2` (wrap XLA DotGeneral operator).
* `XlaGather` (wraps XLA Gather operator). We could use `tf.gather` in some
cases but not always. Also, `tf.gather` has different semantics than `lax.gather`
for out-of-bounds indices.

View File

@ -15,7 +15,7 @@
See README.md for instructions.
"""
import grpc
import grpc # type: ignore[import]
import json
import logging
import requests
@ -26,9 +26,9 @@ from absl import flags
from jax.experimental.jax2tf.examples import mnist_lib # type: ignore
import numpy as np
import tensorflow as tf # type: ignore
import tensorflow_datasets as tfds # type: ignore
from tensorflow_serving.apis import predict_pb2
import tensorflow as tf # type: ignore[import]
import tensorflow_datasets as tfds # type: ignore[import]
from tensorflow_serving.apis import predict_pb2 # type: ignore[import]
from tensorflow_serving.apis import prediction_service_pb2_grpc

View File

@ -21,8 +21,8 @@ from jax.experimental.jax2tf.examples import mnist_lib
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow as tf # type: ignore[import]
import tensorflow_datasets as tfds # type: ignore[import]
flags.DEFINE_string('tflite_file_path',
'/usr/local/google/home/qiuminxu/jax2tf/mnist.tflite',

View File

@ -1,11 +1,11 @@
# Primitives with limited JAX support
*Last generated on: 2021-03-29* (YYYY-MM-DD)
*Last generated on: 2021-05-12* (YYYY-MM-DD)
## Supported data types for primitives
We use a set of 2313 test harnesses to test
the implementation of 122 numeric JAX primitives.
We use a set of 2418 test harnesses to test
the implementation of 121 numeric JAX primitives.
We consider a JAX primitive supported for a particular data
type if it is supported on at least one device type.
The following table shows the dtypes at which primitives
@ -76,7 +76,7 @@ be updated.
| device_put | 16 | all | |
| digamma | 4 | floating | bool, complex, integer |
| div | 20 | inexact, integer | bool |
| dot_general | 125 | all | |
| dot_general | 245 | all | |
| dynamic_slice | 32 | all | |
| dynamic_update_slice | 21 | all | |
| eig | 72 | inexact | bool, integer |
@ -156,7 +156,6 @@ be updated.
| svd | 120 | inexact | bool, integer |
| tan | 6 | inexact | bool, integer |
| tanh | 6 | inexact | bool, integer |
| tie_in | 15 | all | |
| top_k | 15 | bool, floating, integer | complex |
| transpose | 17 | all | |
| triangular_solve | 26 | inexact | bool, integer |
@ -188,6 +187,9 @@ and search for "limitation".
|cummax|unimplemented|complex64|tpu|
|cummin|unimplemented|complex64|tpu|
|cumprod|unimplemented|complex64|tpu|
|dot_general|preferred_element_type=c128 not implemented|complex64|tpu|
|dot_general|preferred_element_type=f64 crashes (b/187884887)|bfloat16, float16, float32|tpu|
|dot_general|preferred_element_type=i64 not implemented|int16, int32, int8|tpu|
|eig|only supported on CPU in JAX|all|tpu, gpu|
|eig|unimplemented|bfloat16, float16|cpu|
|eigh|complex eigh not supported |complex|tpu|
@ -202,7 +204,6 @@ and search for "limitation".
|select_and_scatter_add|works only for 2 or more inactive dimensions|all|tpu|
|svd|complex not implemented. Works in JAX for CPU and GPU with custom kernels|complex|tpu|
|svd|unimplemented|bfloat16, float16|cpu, gpu|
|tie_in|requires omnistaging to be disabled|all|cpu, gpu, tpu|
|triangular_solve|unimplemented|float16|gpu|
## Table generation

View File

@ -68,6 +68,8 @@ def _sanitize_scope_name(name):
# or Python scalar or numpy.ndarray. (A tf.EagerTensor is a tf.Tensor.)
TfVal = Any
DType = Any
PrecisionType = int # Enum xla_data.PrecisionConfig.Precision
def _is_tfval(v: TfVal) -> bool:
if isinstance(v, (tf.Tensor, tf.Variable)):
return True
@ -123,7 +125,6 @@ def _xla_path_disabled_error(primitive_name: str) -> Exception:
@functools.partial(api_util.api_hook, tag="jax2tf_convert")
def convert(fun: Callable, *,
polymorphic_shapes: Optional[Sequence[Any]]=None,
in_shapes=None, # DEPRECATED
with_gradient=True, enable_xla=True) -> Callable:
"""Transforms `fun` to be executed by TensorFlow.
@ -222,13 +223,13 @@ def convert(fun: Callable, *,
raise TypeError(msg)
polymorphic_shapes_ = tuple(polymorphic_shapes)
# Expand the in_shapes to match the argument pytree
# Expand the polymorphic_shapes to match the argument pytree
polymorphic_shapes_flat = tuple(api_util.flatten_axes("jax2tf.convert polymorphic_shapes",
in_tree.children()[0],
polymorphic_shapes_))
# Construct the abstract values for the flat arguments, possibly based on
# the input shapes and the in_shapes if given. May create new shape
# the input shapes and the polymorphic_shapes if given. May create new shape
# variables.
args_avals_flat, shapeenv = _args_to_avals_and_env(args_flat,
polymorphic_shapes_flat)
@ -557,11 +558,16 @@ class TensorFlowTracer(core.Tracer):
elif isinstance(val, (tf.Tensor, tf.Variable)):
val_shape, val_dtype = _tfval_shape_dtype(val)
aval_dtype = np.dtype(self._aval.dtype) # type: ignore[attr-defined]
if val_dtype != aval_dtype and (val_dtype == tf.int32 and aval_dtype == jnp.int64 or
val_dtype == tf.int64 and aval_dtype == jnp.int32 or
val_dtype == tf.float32 and aval_dtype == jnp.float64 or
val_dtype == tf.float64 and aval_dtype == jnp.float32):
# We expect that x64 values are turned into x32
if (val_dtype != aval_dtype and
not config.x64_enabled and
(val_dtype == tf.int32 and aval_dtype == jnp.int64 or
val_dtype == tf.int64 and aval_dtype == jnp.int32 or
val_dtype == tf.float32 and aval_dtype == jnp.float64 or
val_dtype == tf.float64 and aval_dtype == jnp.float32 or
val_dtype == tf.complex128 and aval_dtype == jnp.complex64)):
# If JAX does not have x64 bit mode enabled, it will force the 64-bit
# values to use 32-bit precision. In order to make the TF conversion
# follow JAX's rules, we cast the TF values down to 32-bit mode.
val = tf.cast(val, dtype=aval_dtype)
val_dtype = aval_dtype
@ -569,7 +575,8 @@ class TensorFlowTracer(core.Tracer):
assert aval_dtype == val_dtype, f"expected {aval_dtype} == {val_dtype}"
for aval_dim, val_dim in util.safe_zip(self._aval.shape, val_shape): # type: ignore[attr-defined]
if val_dim is None:
assert isinstance(aval_dim, shape_poly.DimVar), f"expected {self._aval.shape} == {val_shape}" # type: ignore[attr-defined]
assert isinstance(aval_dim,
shape_poly.DimVar), f"expected {self._aval.shape} == {val_shape}" # type: ignore[attr-defined]
elif not isinstance(aval_dim, shape_poly.DimVar):
assert aval_dim == val_dim, f"expected {self._aval.shape} == {val_shape}" # type: ignore[attr-defined]
else:
@ -1082,13 +1089,14 @@ def _conv_general_dimension_numbers_proto(dimension_numbers):
return proto
def _conv_general_precision_config_proto(precision):
def _precision_config_proto(precision: Optional[Tuple[PrecisionType, PrecisionType]]):
"""Convert a pair of operand precisions to an XLA PrecisionConfig (None if unset)."""
if precision is None:
return None
proto = xla_data_pb2.PrecisionConfig()
proto.operand_precision.append(int(precision))
proto.operand_precision.append(int(precision[0]))
proto.operand_precision.append(int(precision[1]))
return proto
# _try_tf_conv returns a Tensor when it succeeds, or a string describing why
@ -1196,9 +1204,11 @@ def _try_tf_conv(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
return error
return convert_dilation_and_compute_result(tf_padding, tf_dim_nums)
def _conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation,
def _conv_general_dilated(lhs, rhs, *,
window_strides, padding, lhs_dilation,
rhs_dilation, dimension_numbers, feature_group_count,
batch_group_count, lhs_shape, rhs_shape, precision,
batch_group_count, lhs_shape, rhs_shape,
precision: Optional[Tuple[PrecisionType, PrecisionType]],
preferred_element_type, _in_avals, _out_aval):
"""Implementation of lax.conv_general_dilated_p using XlaConv."""
if not _enable_xla:
@ -1212,7 +1222,7 @@ def _conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation,
raise _xla_path_disabled_error("conv_general_dilated")
dnums_proto = _conv_general_dimension_numbers_proto(dimension_numbers)
precision_config_proto = _conv_general_precision_config_proto(precision)
precision_config_proto = _precision_config_proto(precision)
assert batch_group_count == 1 # TODO(phawkins): implement batch_group_count
out = tfxla.conv(
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
@ -1226,24 +1236,38 @@ def _conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation,
tf_impl_with_avals[lax.conv_general_dilated_p] = _conv_general_dilated
def _dot_general(lhs, rhs, dimension_numbers, precision, preferred_element_type):
def _dot_general(lhs, rhs, *,
dimension_numbers,
precision: Optional[Tuple[PrecisionType, PrecisionType]],
preferred_element_type: Optional[DType],
_in_avals: Sequence[core.AbstractValue],
_out_aval: core.AbstractValue):
"""Implementation of lax.dot_general_p: XlaDot if XLA is enabled, else TF ops (e.g., einsum)."""
del precision
del preferred_element_type
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
lhs_ndim, rhs_ndim = len(lhs.shape), len(rhs.shape)
if _enable_xla:
dnums_proto = xla_data_pb2.DotDimensionNumbers()
dnums_proto.lhs_contracting_dimensions.extend(lhs_contracting)
dnums_proto.rhs_contracting_dimensions.extend(rhs_contracting)
dnums_proto.lhs_batch_dimensions.extend(lhs_batch)
dnums_proto.rhs_batch_dimensions.extend(rhs_batch)
precision_config_proto = _precision_config_proto(precision)
res = tfxla.dot_general(lhs, rhs, dnums_proto, precision_config_proto,
preferred_element_type=preferred_element_type)
# TODO: in presence of None dimensions, XlaDot shape inference returns
# unknown shape.
res.set_shape(_aval_to_tf_shape(_out_aval))
return res
# This condition ensures that:
# 1) the considered dtype is not tf.bfloat16/tf.int32, which are supported by
# tf.linalg.einsum but not by tf.linalg.matmul;
# 2) the batch dimensions are ordered in the same way in lhs and rhs (this is
# 1) the batch dimensions are ordered in the same way in lhs and rhs (this is
# not strictly necessary, but we would have to reshape the array if that
# were not the case);
# 3) lhs and rhs have the same number of dimensions +/- 1
# 4) the number of non-batch dimensions in both tensors is either 1 or 2
# 5) the contracting dimensions are consistent with those of a classic
# 2) lhs and rhs have the same number of dimensions +/- 1
# 3) the number of non-batch dimensions in both tensors is either 1 or 2
# 4) the contracting dimensions are consistent with those of a classic
# matrix/matrix, vector/matrix or matrix/vector multiplication.
if (not lhs.dtype in [tf.bfloat16, tf.int32]
and lhs_batch == rhs_batch == tuple(range(len(lhs_batch)))
if (lhs_batch == rhs_batch == tuple(range(len(lhs_batch)))
and lhs_ndim - rhs_ndim in [-1, 0, 1]
and 1 <= lhs_ndim - len(lhs_batch) <= 2
and 1 <= rhs_ndim - len(rhs_batch) <= 2
@ -1290,16 +1314,16 @@ def _dot_general(lhs, rhs, dimension_numbers, precision, preferred_element_type)
shared_id = next(new_id)
lhs_axis_ids[lhs_axis] = shared_id
rhs_axis_ids[rhs_axis] = shared_id
lhs_out_axis_ids[lhs_axis] = None
rhs_out_axis_ids[rhs_axis] = None
lhs_out_axis_ids[lhs_axis] = None # type: ignore[call-overload]
rhs_out_axis_ids[rhs_axis] = None # type: ignore[call-overload]
batch_ids = []
for lhs_axis, rhs_axis in zip(lhs_batch, rhs_batch):
shared_id = next(new_id)
lhs_axis_ids[lhs_axis] = shared_id
rhs_axis_ids[rhs_axis] = shared_id
lhs_out_axis_ids[lhs_axis] = None
rhs_out_axis_ids[rhs_axis] = None
lhs_out_axis_ids[lhs_axis] = None # type: ignore[call-overload]
rhs_out_axis_ids[rhs_axis] = None # type: ignore[call-overload]
batch_ids.append(shared_id)
not_none = lambda x: x is not None
@ -1310,7 +1334,7 @@ def _dot_general(lhs, rhs, dimension_numbers, precision, preferred_element_type)
"".join(rhs_axis_ids),
"".join(out_axis_ids))
return tf.linalg.einsum(spec, lhs, rhs)
tf_impl[lax.dot_general_p] = _dot_general
tf_impl_with_avals[lax.dot_general_p] = _dot_general
def _broadcast(operand, *, sizes):

View File

@ -460,17 +460,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
modes="compiled",
# Works for 2D matrices.
enabled=(len(harness.params["lhs_shape"]) > 2)),
custom_numeric(dtypes=dtypes.bfloat16, tol=0.3),
custom_numeric(
dtypes=[np.complex64, np.float32], devices=("cpu", "gpu"),
tol=1e-5),
custom_numeric(
dtypes=[np.complex128, np.float64], devices=("cpu", "gpu"),
tol=1e-12),
custom_numeric(dtypes=np.float32, devices="tpu", tol=0.1),
custom_numeric(dtypes=np.complex64, devices="tpu", tol=0.3),
custom_numeric(dtypes=np.float16, devices=("gpu", "tpu"), tol=0.1),
custom_numeric(dtypes=np.float16, devices="cpu", tol=0.01)
]
@classmethod

View File

@ -2352,22 +2352,49 @@ def _make_dot_general_harness(name,
rhs_shape=(4, 2),
dtype=np.float32,
precision=None,
dimension_numbers=(((1,), (0,)), ((), ()))):
dimension_numbers=(((1,), (0,)), ((), ())),
preferred_element_type=None):
suffix = ""
if precision is not None:
suffix += f"_precision={precision}"
if preferred_element_type is not None:
suffix += f"_preferred={jtu.dtype_str(preferred_element_type)}"
define(
lax.dot_general_p,
f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}_dimensionnumbers={dimension_numbers}_precision={precision}"
f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}_dimensionnumbers={dimension_numbers}{suffix}"
.replace(" ", ""),
lax.dot_general, [
RandArg(lhs_shape, dtype),
RandArg(rhs_shape, dtype),
StaticArg(dimension_numbers),
StaticArg(precision)
StaticArg(precision),
StaticArg(preferred_element_type)
],
dtype=dtype,
lhs_shape=lhs_shape,
rhs_shape=rhs_shape,
dimension_numbers=dimension_numbers,
precision=precision)
precision=precision,
preferred_element_type=preferred_element_type,
jax_unimplemented=[
Limitation(
"preferred_element_type=c128 not implemented",
devices="tpu",
dtypes=np.complex64,
enabled=(preferred_element_type in [np.complex128])),
Limitation(
"preferred_element_type=f64 crashes (b/187884887)",
devices="tpu",
dtypes=(np.float16, jnp.bfloat16, np.float32),
enabled=(preferred_element_type in [np.float64]),
skip_run=True),
Limitation(
"preferred_element_type=i64 not implemented",
devices="tpu",
dtypes=(np.int8, np.int16, np.int32),
enabled=(preferred_element_type in [np.int64])),
],
)
# There are two execution paths in the conversion of dot_general. The main path
@ -2422,6 +2449,28 @@ for lhs_shape, rhs_shape, dimension_numbers in [
rhs_shape=rhs_shape,
dimension_numbers=dimension_numbers)
# Validate preferred element type
# From lax_test.py
preferred_type_combinations = [
(np.float16, np.float16), (np.float16, np.float32), (np.float16, np.float64),
(dtypes.bfloat16, np.float32),
(dtypes.bfloat16, np.float64), (np.float32, np.float32), (np.float32, np.float64),
(np.int8, np.int16), (np.int8, np.int32),
(np.int8, np.int64), (np.int16, np.int32), (np.int16, np.int64),
(np.int32, np.int32), (np.int32, np.int64),
(np.complex64, np.complex128)]
for lhs_shape in [(3,), (4, 3)]:
for rhs_shape in [(3, ), (3, 6)]:
for dtype, preferred_element_type in preferred_type_combinations:
_make_dot_general_harness(
"preferred",
dtype=dtype,
lhs_shape=lhs_shape,
rhs_shape=rhs_shape,
dimension_numbers=(((len(lhs_shape) - 1,), (0,)), ((), ())),
preferred_element_type=preferred_element_type)
def _make_concatenate_harness(name,
*,
@ -2521,8 +2570,6 @@ for batch_group_count, feature_group_count in [
feature_group_count=feature_group_count,
batch_group_count=batch_group_count)
### XXX
# Validate variations of window_strides
for window_strides in [(2, 3)]:
_make_conv_harness("window_strides", window_strides=window_strides)

View File

@ -99,7 +99,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
# If you want to run this test for only one harness, add parameter
# `one_containing="foo"` to parameterized below.
@primitive_harness.parameterized(
primitive_harness.all_harnesses, include_jax_unimpl=False
primitive_harness.all_harnesses, include_jax_unimpl=False,
)
@jtu.ignore_warning(
category=UserWarning, message="Using reduced precision for gradient.*")

View File

@ -332,11 +332,11 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
tf_value_and_grad, autograph=False).get_concrete_function(
tf.TensorSpec([None, None, 8, 9]))
# The shape of the value
self.assertEqual((None, None, 8, 8), tuple(tf_grad.output_shapes[0]))
# The shape of the gradient should match the input
# TODO: there seems to be a bug here, the output should be (None, None, 8, 9)
# self.assertEqual((None, None, 8, None), tuple(tf_grad.output_shapes[1]))
# The shape of the value. This should be (None, None, 8, 8) but the
# shape inference for XlaDot is broken, and returns too many unknown
# dimensions.
self.assertEqual((None, None, None, None), tuple(tf_grad.output_shapes[0]))
self.assertEqual((None, None, None, None), tuple(tf_grad.output_shapes[1]))
def test_gradients_pytree(self):
"""Shape polymorphism with gradients and pytrees for inputs and outputs."""
@ -742,12 +742,18 @@ _POLY_SHAPE_TEST_HARNESSES = [
[RandArg((3,), _f32)],
poly_axes=[0]),
_make_harness("jnp_matmul", "",
_make_harness("jnp_matmul", "0",
jnp.matmul,
[RandArg((7, 8, 4), _f32), RandArg((7, 4, 5), _f32)],
poly_axes=[0, 0],
tol=1e-5),
_make_harness("jnp_matmul", "1",
jnp.matmul,
[RandArg((7, 8, 4), _f32), RandArg((4, 5), _f32)],
poly_axes=[0, None],
tol=1e-5),
_make_harness("jnp_where", "",
jnp.where,
[RandArg((2,), np.bool_), RandArg((), _f32), RandArg((2,), _f32)],
@ -935,6 +941,10 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
"""Tests for primitives that take shape values as parameters."""
# This test runs for all _POLY_SHAPE_TEST_HARNESSES.
# For each primitive "xxx" the
# test will be called "test_prim_xxx_...".
# If you want to run this test for only one harness that includes "foo"
# in the name, add parameter `one_containing="foo"` to parameterized below.
@primitive_harness.parameterized(_POLY_SHAPE_TEST_HARNESSES)
def test_prim(self, harness: Harness):
args = harness.dyn_args_maker(self.rng())

View File

@ -36,7 +36,7 @@ DType = Any
def _make_tf_args(args):
def _convert_to_tensor(v):
if hasattr(v, "dtype"):
tf.convert_to_tensor(v)
return tf.convert_to_tensor(v)
return v
return tf.nest.map_structure(_convert_to_tensor, args)
@ -51,7 +51,6 @@ def _make_tf_input_signature(*tf_args) -> List[tf.TensorSpec]:
return tf.nest.map_structure(_make_one_arg_signature, list(tf_args))
def _run_tf_function(func_tf: Callable, *tf_args, mode: str):
if mode == "eager":
return func_tf(*tf_args) # EAGER
@ -188,13 +187,47 @@ class JaxToTfTestCase(jtu.JaxTestCase):
custom_assert_lim = [l for l in jax2tf_limits if l.custom_assert]
assert len(custom_assert_lim) <= 1, f"Expecting at most one applicable limitation with custom_assert, found {custom_assert_lim}"
if custom_assert_lim:
logging.info(log_message(f"Running custom_assert with tol={max_tol} due to {custom_assert_lim[0]}"))
custom_assert_lim[0].custom_assert(self, result_jax, result_tf, args=args, tol=max_tol)
else:
logging.info(log_message(f"Running default assert with tol={max_tol}"))
# In compiled mode we expect the same result as JAX by default
self.assertAllClose(result_jax, result_tf, atol=max_tol, rtol=max_tol)
try:
if custom_assert_lim:
logging.info(log_message(f"Running custom_assert with tol={max_tol} due to {custom_assert_lim[0]}"))
custom_assert_lim[0].custom_assert(self, result_jax, result_tf, args=args, tol=max_tol)
else:
logging.info(log_message(f"Running default assert with tol={max_tol}"))
self.assertAllClose(result_jax, result_tf, atol=max_tol, rtol=max_tol)
except AssertionError as e:
# Log the non-optimized and optimized HLO for JAX and TF in compiled mode; they should match.
if mode != "compiled":
logging.info(f"[{self._testMethodName}] Not printing HLO because the "
f"mode was {mode}")
raise
logging.info(f"[{self._testMethodName}] Logging HLO "
f"for comparison error {e}")
jax_comp = jax.xla_computation(func_jax)(*args)
jax_hlo = jax_comp.as_hlo_text()
logging.info(f"[{self._testMethodName}] "
f"JAX NON-OPT HLO\n{jax_hlo}")
tf_func_compiled = tf.function(
func_tf,
autograph=False,
jit_compile=True,
input_signature=_make_tf_input_signature(*tf_args))
tf_hlo = tf_func_compiled.experimental_get_compiler_ir(*tf_args)(
stage="hlo")
logging.info(f"[{self._testMethodName}] TF NON-OPT HLO\n{tf_hlo}")
backend = jax.lib.xla_bridge.get_backend()
modules = backend.compile(jax_comp).hlo_modules()
jax_opt_hlo = modules[0].to_string()
logging.info(f"[{self._testMethodName}] "
f"JAX OPT HLO\n{jax_opt_hlo}")
tf_opt_hlo = tf_func_compiled.experimental_get_compiler_ir(*tf_args)(
stage="optimized_hlo")
logging.info(f"[{self._testMethodName}] TF OPT HLO\n{tf_opt_hlo}")
raise
# end "for mode"

View File

@ -784,7 +784,11 @@ def assert_dot_precision(expected_precision, fun, *args):
if eqn.primitive == lax.dot_general_p]
for precision in precisions:
msg = "Unexpected precision: {} != {}".format(expected_precision, precision)
assert precision == expected_precision, msg
if isinstance(precision, tuple):
assert precision[0] == expected_precision, msg
assert precision[1] == expected_precision, msg
else:
assert precision == expected_precision, msg
_CACHED_INDICES: Dict[int, Sequence[int]] = {}

View File

@ -2591,23 +2591,23 @@ class APITest(jtu.JaxTestCase):
with jax.default_matmul_precision("bfloat16"):
x @ x # doesn't crash
jaxpr = jax.make_jaxpr(op.matmul)(x, x)
self.assertIn('precision=DEFAULT', str(jaxpr))
self.assertIn('Precision.DEFAULT', str(jaxpr))
with jax.default_matmul_precision("tensorfloat32"):
jnp.dot(x, x) # doesn't crash
jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
self.assertIn('precision=HIGH\n', str(jaxpr))
self.assertIn('Precision.HIGH', str(jaxpr))
with jax.default_matmul_precision("float32"):
jnp.dot(x, x) # doesn't crash
jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
self.assertIn('precision=HIGHEST', str(jaxpr))
self.assertIn('Precision.HIGHEST', str(jaxpr))
dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)
with jax.default_matmul_precision("tensorfloat32"):
dot(x, x) # doesn't crash
jaxpr = jax.make_jaxpr(dot)(x, x)
self.assertIn('precision=HIGHEST', str(jaxpr))
self.assertIn('Precision.HIGHEST', str(jaxpr))
def test_dot_precision_flag(self):
x = jnp.zeros((2, 2))
@ -2619,7 +2619,7 @@ class APITest(jtu.JaxTestCase):
jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
finally:
config.FLAGS.jax_default_matmul_precision = prev_val
self.assertIn('precision=HIGH', str(jaxpr))
self.assertIn('Precision.HIGH', str(jaxpr))
self.assertEqual(prev_val, config._read("jax_default_matmul_precision"))
prev_val = config._read("jax_default_matmul_precision")
@ -2629,7 +2629,7 @@ class APITest(jtu.JaxTestCase):
jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
finally:
config.update('jax_default_matmul_precision', prev_val)
self.assertIn('precision=HIGH', str(jaxpr))
self.assertIn('Precision.HIGH', str(jaxpr))
self.assertEqual(prev_val, config._read("jax_default_matmul_precision"))
@unittest.skipIf(jax.lib._xla_extension_version <= 17,

View File

@ -398,7 +398,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
result, pullback = api.vjp(dot, lhs, rhs)
gresult = lax.zeros_like_array(result)
s = str(api.make_jaxpr(pullback)(gresult))
assert "precision=HIGHEST" in s
assert "Precision.HIGHEST" in s
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
@ -429,7 +429,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
result, pullback = api.vjp(dot_general, lhs, rhs)
gresult = lax.zeros_like_array(result)
s = str(api.make_jaxpr(pullback)(gresult))
assert "precision=HIGHEST" in s
assert "Precision.HIGHEST" in s
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}".format(