From 235eb8c2b45199d25bdcee5ab5bf5b187a949fe6 Mon Sep 17 00:00:00 2001 From: George Necula Date: Wed, 12 May 2021 02:29:51 -0700 Subject: [PATCH] Copybara import of the project: -- 1ecf4f02891cad70cc8f094b49cf2458105ca366 by George Necula : [jax2tf] Change the conversion of dot_general to use XLA op. Instead of converting the dot_general to a sea of TF ops, when we enable_xla we just use the XLA op. This has the advantage that it also supports the preferred_element_type. Fixed bug with passing the precision parameter to TF. Also improved tests to print the HLO in case of numerical errors. COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/6717 from gnecula:tf_dot 1ecf4f02891cad70cc8f094b49cf2458105ca366 PiperOrigin-RevId: 373326655 --- CHANGELOG.md | 6 ++ jax/_src/lax/lax.py | 19 +++-- jax/experimental/jax2tf/README.md | 1 + .../examples/serving/model_server_request.py | 8 +- .../jax2tf/examples/tflite/mnist/mnist.py | 4 +- .../jax2tf/g3doc/jax_primitives_coverage.md | 13 +-- jax/experimental/jax2tf/jax2tf.py | 84 ++++++++++++------- .../jax2tf/tests/jax2tf_limitations.py | 11 --- .../jax2tf/tests/primitive_harness.py | 59 +++++++++++-- .../jax2tf/tests/primitives_test.py | 2 +- .../jax2tf/tests/shape_poly_test.py | 22 +++-- jax/experimental/jax2tf/tests/tf_test_util.py | 51 +++++++++-- jax/test_util.py | 6 +- tests/api_test.py | 12 +-- tests/lax_autodiff_test.py | 4 +- 15 files changed, 212 insertions(+), 90 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d7b3e6877..59440cbf7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,9 +11,15 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK. ## jax 0.2.14 (unreleased) * [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.13...master). +* Bug fixes: + * The {func}`jax2tf.convert` now converts `lax.dot_general` using the + `XlaDot` TensorFlow op, for better fidelity w.r.t. JAX numerical precision + ({jax-issue}`#6717`). + ## jaxlib 0.1.67 (unreleased) ## jaxlib 0.1.66 (May 11 2021) + * New features: * CUDA 11.1 wheels are now supported on all CUDA 11 versions 11.1 or higher. diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 18772f198..e3f82ca17 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -6652,28 +6652,35 @@ def remaining(original, *removed_lists): return [i for i in original if i not in removed] -def _canonicalize_precision(precision): +def _canonicalize_precision(precision: PrecisionLike) -> Optional[Tuple[PrecisionType, PrecisionType]]: + """Turns an API precision specification, into a pair of enumeration values. + + The API can take the precision as a string, or int, and either as a single + value to apply to both operands, or as a sequence of two values. 
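+
+  For example, "bfloat16" canonicalizes to (Precision.DEFAULT, Precision.DEFAULT),
+  "float32" to (Precision.HIGHEST, Precision.HIGHEST), and the pair
+  ("tensorfloat32", "float32") to (Precision.HIGH, Precision.HIGHEST).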
+ """ if precision is None: if config.jax_default_matmul_precision is None: return None try: - return _precision_strings[config.jax_default_matmul_precision] + precision = _precision_strings[config.jax_default_matmul_precision] + return (precision, precision) except KeyError: raise ValueError( "jax_default_matmul_precision flag must be set to None or a value in " f"{_precision_strings}, but got {config.jax_default_matmul_precision}" ) from None elif isinstance(precision, str) and precision in _precision_strings: - return _precision_strings.get(precision) + precision = _precision_strings.get(precision) + return (precision, precision) elif isinstance(precision, Precision): - return precision + return (precision, precision) elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and all(isinstance(p, Precision) for p in precision)): - return precision + return precision # type: ignore[return-value] elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and all(isinstance(s, str) for s in precision)): s1, s2 = precision - return (_canonicalize_precision(s1), _canonicalize_precision(s2)) + return (_canonicalize_precision(s1)[0], _canonicalize_precision(s2)[0]) # type: ignore else: raise ValueError( f"Precision argument must be None, a string in {_precision_strings}, " diff --git a/jax/experimental/jax2tf/README.md b/jax/experimental/jax2tf/README.md index c837075e6..9118c5753 100644 --- a/jax/experimental/jax2tf/README.md +++ b/jax/experimental/jax2tf/README.md @@ -456,6 +456,7 @@ We use the following TFXLA ops: * `XlaPad` (wraps XLA Pad operator). We use this instead of `tf.pad` in order to support `lax.pad` interior padding (dilation) or negative edge padding. * `XlaConv` (wraps XLA ConvGeneralDilated operator). + * `XlaDot` and `XlaDotV2` (wraps XLA DotGeneral operator). * `XlaGather` (wraps XLA Gather operator). We could use `tf.gather` in some cases but not always. Also, `tf.gather` has a different semantics than `lax.gather` for index out of bounds. diff --git a/jax/experimental/jax2tf/examples/serving/model_server_request.py b/jax/experimental/jax2tf/examples/serving/model_server_request.py index a17a6a973..e32e16901 100644 --- a/jax/experimental/jax2tf/examples/serving/model_server_request.py +++ b/jax/experimental/jax2tf/examples/serving/model_server_request.py @@ -15,7 +15,7 @@ See README.md for instructions. 
""" -import grpc +import grpc # type: ignore[import] import json import logging import requests @@ -26,9 +26,9 @@ from absl import flags from jax.experimental.jax2tf.examples import mnist_lib # type: ignore import numpy as np -import tensorflow as tf # type: ignore -import tensorflow_datasets as tfds # type: ignore -from tensorflow_serving.apis import predict_pb2 +import tensorflow as tf # type: ignore[import] +import tensorflow_datasets as tfds # type: ignore[import] +from tensorflow_serving.apis import predict_pb2 # type: ignore[import] from tensorflow_serving.apis import prediction_service_pb2_grpc diff --git a/jax/experimental/jax2tf/examples/tflite/mnist/mnist.py b/jax/experimental/jax2tf/examples/tflite/mnist/mnist.py index 60b83e305..08b96ac0f 100644 --- a/jax/experimental/jax2tf/examples/tflite/mnist/mnist.py +++ b/jax/experimental/jax2tf/examples/tflite/mnist/mnist.py @@ -21,8 +21,8 @@ from jax.experimental.jax2tf.examples import mnist_lib import numpy as np -import tensorflow as tf -import tensorflow_datasets as tfds +import tensorflow as tf # type: ignore[import] +import tensorflow_datasets as tfds # type: ignore[import] flags.DEFINE_string('tflite_file_path', '/usr/local/google/home/qiuminxu/jax2tf/mnist.tflite', diff --git a/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md b/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md index 147e44b0d..f08eccc05 100644 --- a/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md +++ b/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md @@ -1,11 +1,11 @@ # Primitives with limited JAX support -*Last generated on: 2021-03-29* (YYYY-MM-DD) +*Last generated on: 2021-05-12* (YYYY-MM-DD) ## Supported data types for primitives -We use a set of 2313 test harnesses to test -the implementation of 122 numeric JAX primitives. +We use a set of 2418 test harnesses to test +the implementation of 121 numeric JAX primitives. We consider a JAX primitive supported for a particular data type if it is supported on at least one device type. The following table shows the dtypes at which primitives @@ -76,7 +76,7 @@ be updated. | device_put | 16 | all | | | digamma | 4 | floating | bool, complex, integer | | div | 20 | inexact, integer | bool | -| dot_general | 125 | all | | +| dot_general | 245 | all | | | dynamic_slice | 32 | all | | | dynamic_update_slice | 21 | all | | | eig | 72 | inexact | bool, integer | @@ -156,7 +156,6 @@ be updated. | svd | 120 | inexact | bool, integer | | tan | 6 | inexact | bool, integer | | tanh | 6 | inexact | bool, integer | -| tie_in | 15 | all | | | top_k | 15 | bool, floating, integer | complex | | transpose | 17 | all | | | triangular_solve | 26 | inexact | bool, integer | @@ -188,6 +187,9 @@ and search for "limitation". |cummax|unimplemented|complex64|tpu| |cummin|unimplemented|complex64|tpu| |cumprod|unimplemented|complex64|tpu| +|dot_general|preferred_element_type=c128 not implemented|complex64|tpu| +|dot_general|preferred_element_type=f64 crashes (b/187884887)|bfloat16, float16, float32|tpu| +|dot_general|preferred_element_type=i64 not implemented|int16, int32, int8|tpu| |eig|only supported on CPU in JAX|all|tpu, gpu| |eig|unimplemented|bfloat16, float16|cpu| |eigh|complex eigh not supported |complex|tpu| @@ -202,7 +204,6 @@ and search for "limitation". |select_and_scatter_add|works only for 2 or more inactive dimensions|all|tpu| |svd|complex not implemented. 
Works in JAX for CPU and GPU with custom kernels|complex|tpu| |svd|unimplemented|bfloat16, float16|cpu, gpu| -|tie_in|requires omnistaging to be disabled|all|cpu, gpu, tpu| |triangular_solve|unimplemented|float16|gpu| ## Table generation diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index ce0836406..5ea7fe27a 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -68,6 +68,8 @@ def _sanitize_scope_name(name): # or Python scalar or numpy.ndarray. (A tf.EagerTensor is a tf.Tensor.) TfVal = Any DType = Any +PrecisionType = int # Enum xla_data.PrecisionConfig.Precision + def _is_tfval(v: TfVal) -> bool: if isinstance(v, (tf.Tensor, tf.Variable)): return True @@ -123,7 +125,6 @@ def _xla_path_disabled_error(primitive_name: str) -> Exception: @functools.partial(api_util.api_hook, tag="jax2tf_convert") def convert(fun: Callable, *, polymorphic_shapes: Optional[Sequence[Any]]=None, - in_shapes=None, # DEPRECATED with_gradient=True, enable_xla=True) -> Callable: """Transforms `fun` to be executed by TensorFlow. @@ -222,13 +223,13 @@ def convert(fun: Callable, *, raise TypeError(msg) polymorphic_shapes_ = tuple(polymorphic_shapes) - # Expand the in_shapes to match the argument pytree + # Expand the polymorphic_shapes to match the argument pytree polymorphic_shapes_flat = tuple(api_util.flatten_axes("jax2tf.convert polymorphic_shapes", in_tree.children()[0], polymorphic_shapes_)) # Construct the abstract values for the flat arguments, possibly based on - # the input shapes and the in_shapes if given. May create new shape + # the input shapes and the polymorphic_shapes if given. May create new shape # variables. args_avals_flat, shapeenv = _args_to_avals_and_env(args_flat, polymorphic_shapes_flat) @@ -557,11 +558,16 @@ class TensorFlowTracer(core.Tracer): elif isinstance(val, (tf.Tensor, tf.Variable)): val_shape, val_dtype = _tfval_shape_dtype(val) aval_dtype = np.dtype(self._aval.dtype) # type: ignore[attr-defined] - if val_dtype != aval_dtype and (val_dtype == tf.int32 and aval_dtype == jnp.int64 or - val_dtype == tf.int64 and aval_dtype == jnp.int32 or - val_dtype == tf.float32 and aval_dtype == jnp.float64 or - val_dtype == tf.float64 and aval_dtype == jnp.float32): - # We expect that x64 values are turned into x32 + if (val_dtype != aval_dtype and + not config.x64_enabled and + (val_dtype == tf.int32 and aval_dtype == jnp.int64 or + val_dtype == tf.int64 and aval_dtype == jnp.int32 or + val_dtype == tf.float32 and aval_dtype == jnp.float64 or + val_dtype == tf.float64 and aval_dtype == jnp.float32 or + val_dtype == tf.complex128 and aval_dtype == jnp.complex64)): + # If JAX does not have x64 bit mode enabled, it will force the 64-bit + # values to use 32-bit precision. In order to make the TF conversion + # follow JAX's rules, we cast the TF values down to 32-bit mode. 
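+        # For example, a tf.float64 value whose JAX abstract value is float32
+        # (because x64 is disabled) is cast here down to tf.float32.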
val = tf.cast(val, dtype=aval_dtype) val_dtype = aval_dtype @@ -569,7 +575,8 @@ class TensorFlowTracer(core.Tracer): assert aval_dtype == val_dtype, f"expected {aval_dtype} == {val_dtype}" for aval_dim, val_dim in util.safe_zip(self._aval.shape, val_shape): # type: ignore[attr-defined] if val_dim is None: - assert isinstance(aval_dim, shape_poly.DimVar), f"expected {self._aval.shape} == {val_shape}" # type: ignore[attr-defined] + assert isinstance(aval_dim, + shape_poly.DimVar), f"expected {self._aval.shape} == {val_shape}" # type: ignore[attr-defined] elif not isinstance(aval_dim, shape_poly.DimVar): assert aval_dim == val_dim, f"expected {self._aval.shape} == {val_shape}" # type: ignore[attr-defined] else: @@ -1082,13 +1089,14 @@ def _conv_general_dimension_numbers_proto(dimension_numbers): return proto -def _conv_general_precision_config_proto(precision): +def _precision_config_proto(precision: Optional[Tuple[PrecisionType, PrecisionType]]): """Convert an integer to an XLA.PrecisionConfig.""" if precision is None: return None proto = xla_data_pb2.PrecisionConfig() - proto.operand_precision.append(int(precision)) + proto.operand_precision.append(int(precision[0])) + proto.operand_precision.append(int(precision[1])) return proto # _try_tf_conv returns a Tensor when it succeeds, or a string describing why @@ -1196,9 +1204,11 @@ def _try_tf_conv(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, return error return convert_dilation_and_compute_result(tf_padding, tf_dim_nums) -def _conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation, +def _conv_general_dilated(lhs, rhs, *, + window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count, - batch_group_count, lhs_shape, rhs_shape, precision, + batch_group_count, lhs_shape, rhs_shape, + precision: Optional[Tuple[PrecisionType, PrecisionType]], preferred_element_type, _in_avals, _out_aval): """Implementation of lax.conv_general_dilated_p using XlaConv.""" if not _enable_xla: @@ -1212,7 +1222,7 @@ def _conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation, raise _xla_path_disabled_error("conv_general_dilated") dnums_proto = _conv_general_dimension_numbers_proto(dimension_numbers) - precision_config_proto = _conv_general_precision_config_proto(precision) + precision_config_proto = _precision_config_proto(precision) assert batch_group_count == 1 # TODO(phawkins): implement batch_group_count out = tfxla.conv( lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, @@ -1226,24 +1236,38 @@ def _conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation, tf_impl_with_avals[lax.conv_general_dilated_p] = _conv_general_dilated -def _dot_general(lhs, rhs, dimension_numbers, precision, preferred_element_type): +def _dot_general(lhs, rhs, *, + dimension_numbers, + precision: Optional[Tuple[PrecisionType, PrecisionType]], + preferred_element_type: Optional[DType], + _in_avals: Sequence[core.AbstractValue], + _out_aval: core.AbstractValue): """Implementation of lax.dot_general_p in terms of tf.linalg.einsum.""" - del precision - del preferred_element_type (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers lhs_ndim, rhs_ndim = len(lhs.shape), len(rhs.shape) + if _enable_xla: + dnums_proto = xla_data_pb2.DotDimensionNumbers() + dnums_proto.lhs_contracting_dimensions.extend(lhs_contracting) + dnums_proto.rhs_contracting_dimensions.extend(rhs_contracting) + dnums_proto.lhs_batch_dimensions.extend(lhs_batch) + 
dnums_proto.rhs_batch_dimensions.extend(rhs_batch) + precision_config_proto = _precision_config_proto(precision) + res = tfxla.dot_general(lhs, rhs, dnums_proto, precision_config_proto, + preferred_element_type=preferred_element_type) + # TODO: in presence of None dimensions, XlaDot shape inference returns + # unknown shape. + res.set_shape(_aval_to_tf_shape(_out_aval)) + return res + # This condition ensures that: - # 1) the considered dtype is not tf.bfloat16/tf.int32, which are supported by - # tf.linalg.einsum but not by tf.linalg.matmul; - # 2) the batch dimensions are ordered in the same way in lhs and rhs (this is + # 1) the batch dimensions are ordered in the same way in lhs and rhs (this is # not strictly necessary, but we would have to reshape the array if that # were not the case; - # 3) lhs and rhs have the same number of dimensions +/- 1 - # 4) the number of non-batch dimensions in both tensors is either 1 or 2 - # 5) the contracting dimensions are consistent with those of a classic + # 2) lhs and rhs have the same number of dimensions +/- 1 + # 3) the number of non-batch dimensions in both tensors is either 1 or 2 + # 4) the contracting dimensions are consistent with those of a classic # matrix/matrix, vector/matrix or matrix/vector multiplication. - if (not lhs.dtype in [tf.bfloat16, tf.int32] - and lhs_batch == rhs_batch == tuple(range(len(lhs_batch))) + if (lhs_batch == rhs_batch == tuple(range(len(lhs_batch))) and lhs_ndim - rhs_ndim in [-1, 0, 1] and 1 <= lhs_ndim - len(lhs_batch) <= 2 and 1 <= rhs_ndim - len(rhs_batch) <= 2 @@ -1290,16 +1314,16 @@ def _dot_general(lhs, rhs, dimension_numbers, precision, preferred_element_type) shared_id = next(new_id) lhs_axis_ids[lhs_axis] = shared_id rhs_axis_ids[rhs_axis] = shared_id - lhs_out_axis_ids[lhs_axis] = None - rhs_out_axis_ids[rhs_axis] = None + lhs_out_axis_ids[lhs_axis] = None # type: ignore[call-overload] + rhs_out_axis_ids[rhs_axis] = None # type: ignore[call-overload] batch_ids = [] for lhs_axis, rhs_axis in zip(lhs_batch, rhs_batch): shared_id = next(new_id) lhs_axis_ids[lhs_axis] = shared_id rhs_axis_ids[rhs_axis] = shared_id - lhs_out_axis_ids[lhs_axis] = None - rhs_out_axis_ids[rhs_axis] = None + lhs_out_axis_ids[lhs_axis] = None # type: ignore[call-overload] + rhs_out_axis_ids[rhs_axis] = None # type: ignore[call-overload] batch_ids.append(shared_id) not_none = lambda x: x is not None @@ -1310,7 +1334,7 @@ def _dot_general(lhs, rhs, dimension_numbers, precision, preferred_element_type) "".join(rhs_axis_ids), "".join(out_axis_ids)) return tf.linalg.einsum(spec, lhs, rhs) -tf_impl[lax.dot_general_p] = _dot_general +tf_impl_with_avals[lax.dot_general_p] = _dot_general def _broadcast(operand, *, sizes): diff --git a/jax/experimental/jax2tf/tests/jax2tf_limitations.py b/jax/experimental/jax2tf/tests/jax2tf_limitations.py index 3e5555f7d..db7729380 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_limitations.py +++ b/jax/experimental/jax2tf/tests/jax2tf_limitations.py @@ -460,17 +460,6 @@ class Jax2TfLimitation(primitive_harness.Limitation): modes="compiled", # Works for 2D matrices. 
enabled=(len(harness.params["lhs_shape"]) > 2)), - custom_numeric(dtypes=dtypes.bfloat16, tol=0.3), - custom_numeric( - dtypes=[np.complex64, np.float32], devices=("cpu", "gpu"), - tol=1e-5), - custom_numeric( - dtypes=[np.complex128, np.float64], devices=("cpu", "gpu"), - tol=1e-12), - custom_numeric(dtypes=np.float32, devices="tpu", tol=0.1), - custom_numeric(dtypes=np.complex64, devices="tpu", tol=0.3), - custom_numeric(dtypes=np.float16, devices=("gpu", "tpu"), tol=0.1), - custom_numeric(dtypes=np.float16, devices="cpu", tol=0.01) ] @classmethod diff --git a/jax/experimental/jax2tf/tests/primitive_harness.py b/jax/experimental/jax2tf/tests/primitive_harness.py index 2c1ff31b3..9964a8056 100644 --- a/jax/experimental/jax2tf/tests/primitive_harness.py +++ b/jax/experimental/jax2tf/tests/primitive_harness.py @@ -2352,22 +2352,49 @@ def _make_dot_general_harness(name, rhs_shape=(4, 2), dtype=np.float32, precision=None, - dimension_numbers=(((1,), (0,)), ((), ()))): + dimension_numbers=(((1,), (0,)), ((), ())), + preferred_element_type=None): + suffix = "" + if precision is not None: + suffix += f"_precision={precision}" + if preferred_element_type is not None: + suffix += f"_preferred={jtu.dtype_str(preferred_element_type)}" define( lax.dot_general_p, - f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}_dimensionnumbers={dimension_numbers}_precision={precision}" + f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}_dimensionnumbers={dimension_numbers}{suffix}" .replace(" ", ""), lax.dot_general, [ RandArg(lhs_shape, dtype), RandArg(rhs_shape, dtype), StaticArg(dimension_numbers), - StaticArg(precision) + StaticArg(precision), + StaticArg(preferred_element_type) ], dtype=dtype, lhs_shape=lhs_shape, rhs_shape=rhs_shape, dimension_numbers=dimension_numbers, - precision=precision) + precision=precision, + preferred_element_type=preferred_element_type, + jax_unimplemented=[ + Limitation( + "preferred_element_type=c128 not implemented", + devices="tpu", + dtypes=np.complex64, + enabled=(preferred_element_type in [np.complex128])), + Limitation( + "preferred_element_type=f64 crashes (b/187884887)", + devices="tpu", + dtypes=(np.float16, jnp.bfloat16, np.float32), + enabled=(preferred_element_type in [np.float64]), + skip_run=True), + Limitation( + "preferred_element_type=i64 not implemented", + devices="tpu", + dtypes=(np.int8, np.int16, np.int32), + enabled=(preferred_element_type in [np.int64])), + ], + ) # There are two execution paths in the conversion of dot_general. 
The main path @@ -2422,6 +2449,28 @@ for lhs_shape, rhs_shape, dimension_numbers in [ rhs_shape=rhs_shape, dimension_numbers=dimension_numbers) +# Validate preferred element type +# From lax_test.py +preferred_type_combinations = [ + (np.float16, np.float16), (np.float16, np.float32), (np.float16, np.float64), + (dtypes.bfloat16, np.float32), + (dtypes.bfloat16, np.float64), (np.float32, np.float32), (np.float32, np.float64), + (np.int8, np.int16), (np.int8, np.int32), + (np.int8, np.int64), (np.int16, np.int32), (np.int16, np.int64), + (np.int32, np.int32), (np.int32, np.int64), + (np.complex64, np.complex128)] + +for lhs_shape in [(3,), (4, 3)]: + for rhs_shape in [(3, ), (3, 6)]: + for dtype, preferred_element_type in preferred_type_combinations: + _make_dot_general_harness( + "preferred", + dtype=dtype, + lhs_shape=lhs_shape, + rhs_shape=rhs_shape, + dimension_numbers=(((len(lhs_shape) - 1,), (0,)), ((), ())), + preferred_element_type=preferred_element_type) + def _make_concatenate_harness(name, *, @@ -2521,8 +2570,6 @@ for batch_group_count, feature_group_count in [ feature_group_count=feature_group_count, batch_group_count=batch_group_count) -### XXX - # Validate variations of window_strides for window_strides in [(2, 3)]: _make_conv_harness("window_strides", window_strides=window_strides) diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py index 552102356..21168ef1f 100644 --- a/jax/experimental/jax2tf/tests/primitives_test.py +++ b/jax/experimental/jax2tf/tests/primitives_test.py @@ -99,7 +99,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase): # If you want to run this test for only one harness, add parameter # `one_containing="foo"` to parameterized below. @primitive_harness.parameterized( - primitive_harness.all_harnesses, include_jax_unimpl=False + primitive_harness.all_harnesses, include_jax_unimpl=False, ) @jtu.ignore_warning( category=UserWarning, message="Using reduced precision for gradient.*") diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index aaad88024..79b4cd64d 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -332,11 +332,11 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase): tf_value_and_grad, autograph=False).get_concrete_function( tf.TensorSpec([None, None, 8, 9])) - # The shape of the value - self.assertEqual((None, None, 8, 8), tuple(tf_grad.output_shapes[0])) - # The shape of the gradient should match the input - # TODO: there seems to be a bug here, the output should be (None, None, 8, 9) - # self.assertEqual((None, None, 8, None), tuple(tf_grad.output_shapes[1])) + # The shape of the value. This should be (None, None, 8, 8) but the + # shape inference for XlaDot is broken, and returns too many unknown + # dimensions. 
+ self.assertEqual((None, None, None, None), tuple(tf_grad.output_shapes[0])) + self.assertEqual((None, None, None, None), tuple(tf_grad.output_shapes[1])) def test_gradients_pytree(self): """Shape polymorphism with gradients and pytrees for inputs and outputs.""" @@ -742,12 +742,18 @@ _POLY_SHAPE_TEST_HARNESSES = [ [RandArg((3,), _f32)], poly_axes=[0]), - _make_harness("jnp_matmul", "", + _make_harness("jnp_matmul", "0", jnp.matmul, [RandArg((7, 8, 4), _f32), RandArg((7, 4, 5), _f32)], poly_axes=[0, 0], tol=1e-5), + _make_harness("jnp_matmul", "1", + jnp.matmul, + [RandArg((7, 8, 4), _f32), RandArg((4, 5), _f32)], + poly_axes=[0, None], + tol=1e-5), + _make_harness("jnp_where", "", jnp.where, [RandArg((2,), np.bool_), RandArg((), _f32), RandArg((2,), _f32)], @@ -935,6 +941,10 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase): """Tests for primitives that take shape values as parameters.""" # This test runs for all _POLY_SHAPE_PRIMITIVE_HARNESSES. + # For each primitive "xxx" the + # test will be called "test_prim_xxx_...". + # If you want to run this test for only one harness that includes "foo" + # in the name, add parameter `one_containing="foo"` to parameterized below. @primitive_harness.parameterized(_POLY_SHAPE_TEST_HARNESSES) def test_prim(self, harness: Harness): args = harness.dyn_args_maker(self.rng()) diff --git a/jax/experimental/jax2tf/tests/tf_test_util.py b/jax/experimental/jax2tf/tests/tf_test_util.py index 116c662f0..7bd8572d9 100644 --- a/jax/experimental/jax2tf/tests/tf_test_util.py +++ b/jax/experimental/jax2tf/tests/tf_test_util.py @@ -36,7 +36,7 @@ DType = Any def _make_tf_args(args): def _convert_to_tensor(v): if hasattr(v, "dtype"): - tf.convert_to_tensor(v) + return tf.convert_to_tensor(v) return v return tf.nest.map_structure(_convert_to_tensor, args) @@ -51,7 +51,6 @@ def _make_tf_input_signature(*tf_args) -> List[tf.TensorSpec]: return tf.nest.map_structure(_make_one_arg_signature, list(tf_args)) - def _run_tf_function(func_tf: Callable, *tf_args, mode: str): if mode == "eager": return func_tf(*tf_args) # EAGER @@ -188,13 +187,47 @@ class JaxToTfTestCase(jtu.JaxTestCase): custom_assert_lim = [l for l in jax2tf_limits if l.custom_assert] assert len(custom_assert_lim) <= 1, f"Expecting at most one applicable limitation with custom_assert, found {custom_assert_lim}" - if custom_assert_lim: - logging.info(log_message(f"Running custom_assert with tol={max_tol} due to {custom_assert_lim[0]}")) - custom_assert_lim[0].custom_assert(self, result_jax, result_tf, args=args, tol=max_tol) - else: - logging.info(log_message(f"Running default assert with tol={max_tol}")) - # In compiled mode we expect the same result as JAX by default - self.assertAllClose(result_jax, result_tf, atol=max_tol, rtol=max_tol) + try: + if custom_assert_lim: + logging.info(log_message(f"Running custom_assert with tol={max_tol} due to {custom_assert_lim[0]}")) + custom_assert_lim[0].custom_assert(self, result_jax, result_tf, args=args, tol=max_tol) + else: + logging.info(log_message(f"Running default assert with tol={max_tol}")) + self.assertAllClose(result_jax, result_tf, atol=max_tol, rtol=max_tol) + except AssertionError as e: + # Log the optimized HLO for compiled mode, it should match. 
+ if mode != "compiled": + logging.info(f"[{self._testMethodName}] Not printing HLO because the " + f"mode was {mode}") + raise + + logging.info(f"[{self._testMethodName}] Logging HLO " + f"for comparison error {e}") + + jax_comp = jax.xla_computation(func_jax)(*args) + jax_hlo = jax_comp.as_hlo_text() + logging.info(f"[{self._testMethodName}] " + f"JAX NON-OPT HLO\n{jax_hlo}") + + tf_func_compiled = tf.function( + func_tf, + autograph=False, + jit_compile=True, + input_signature=_make_tf_input_signature(*tf_args)) + tf_hlo = tf_func_compiled.experimental_get_compiler_ir(*tf_args)( + stage="hlo") + logging.info(f"[{self._testMethodName}] TF NON-OPT HLO\n{tf_hlo}") + + backend = jax.lib.xla_bridge.get_backend() + modules = backend.compile(jax_comp).hlo_modules() + jax_opt_hlo = modules[0].to_string() + logging.info(f"[{self._testMethodName}] " + f"JAX OPT HLO\n{jax_opt_hlo}") + + tf_opt_hlo = tf_func_compiled.experimental_get_compiler_ir(*tf_args)( + stage="optimized_hlo") + logging.info(f"[{self._testMethodName}] TF OPT HLO\n{tf_opt_hlo}") + raise # end "for mode" diff --git a/jax/test_util.py b/jax/test_util.py index 32818b742..027709f39 100644 --- a/jax/test_util.py +++ b/jax/test_util.py @@ -784,7 +784,11 @@ def assert_dot_precision(expected_precision, fun, *args): if eqn.primitive == lax.dot_general_p] for precision in precisions: msg = "Unexpected precision: {} != {}".format(expected_precision, precision) - assert precision == expected_precision, msg + if isinstance(precision, tuple): + assert precision[0] == expected_precision, msg + assert precision[1] == expected_precision, msg + else: + assert precision == expected_precision, msg _CACHED_INDICES: Dict[int, Sequence[int]] = {} diff --git a/tests/api_test.py b/tests/api_test.py index 57a0ed6d1..0aa3f0272 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -2591,23 +2591,23 @@ class APITest(jtu.JaxTestCase): with jax.default_matmul_precision("bfloat16"): x @ x # doesn't crash jaxpr = jax.make_jaxpr(op.matmul)(x, x) - self.assertIn('precision=DEFAULT', str(jaxpr)) + self.assertIn('Precision.DEFAULT', str(jaxpr)) with jax.default_matmul_precision("tensorfloat32"): jnp.dot(x, x) # doesn't crash jaxpr = jax.make_jaxpr(jnp.dot)(x, x) - self.assertIn('precision=HIGH\n', str(jaxpr)) + self.assertIn('Precision.HIGH', str(jaxpr)) with jax.default_matmul_precision("float32"): jnp.dot(x, x) # doesn't crash jaxpr = jax.make_jaxpr(jnp.dot)(x, x) - self.assertIn('precision=HIGHEST', str(jaxpr)) + self.assertIn('Precision.HIGHEST', str(jaxpr)) dot = partial(jnp.dot, precision=lax.Precision.HIGHEST) with jax.default_matmul_precision("tensorfloat32"): dot(x, x) # doesn't crash jaxpr = jax.make_jaxpr(dot)(x, x) - self.assertIn('precision=HIGHEST', str(jaxpr)) + self.assertIn('Precision.HIGHEST', str(jaxpr)) def test_dot_precision_flag(self): x = jnp.zeros((2, 2)) @@ -2619,7 +2619,7 @@ class APITest(jtu.JaxTestCase): jaxpr = jax.make_jaxpr(jnp.dot)(x, x) finally: config.FLAGS.jax_default_matmul_precision = prev_val - self.assertIn('precision=HIGH', str(jaxpr)) + self.assertIn('Precision.HIGH', str(jaxpr)) self.assertEqual(prev_val, config._read("jax_default_matmul_precision")) prev_val = config._read("jax_default_matmul_precision") @@ -2629,7 +2629,7 @@ class APITest(jtu.JaxTestCase): jaxpr = jax.make_jaxpr(jnp.dot)(x, x) finally: config.update('jax_default_matmul_precision', prev_val) - self.assertIn('precision=HIGH', str(jaxpr)) + self.assertIn('Precision.HIGH', str(jaxpr)) self.assertEqual(prev_val, 
config._read("jax_default_matmul_precision")) @unittest.skipIf(jax.lib._xla_extension_version <= 17, diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py index 42744ac17..b798cfffe 100644 --- a/tests/lax_autodiff_test.py +++ b/tests/lax_autodiff_test.py @@ -398,7 +398,7 @@ class LaxAutodiffTest(jtu.JaxTestCase): result, pullback = api.vjp(dot, lhs, rhs) gresult = lax.zeros_like_array(result) s = str(api.make_jaxpr(pullback)(gresult)) - assert "precision=HIGHEST" in s + assert "Precision.HIGHEST" in s @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": @@ -429,7 +429,7 @@ class LaxAutodiffTest(jtu.JaxTestCase): result, pullback = api.vjp(dot_general, lhs, rhs) gresult = lax.zeros_like_array(result) s = str(api.make_jaxpr(pullback)(gresult)) - assert "precision=HIGHEST" in s + assert "Precision.HIGHEST" in s @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}".format(
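Illustrative usage sketch (not part of the patch itself; assumes a JAX and
TensorFlow installation that includes this change): with `enable_xla=True`
(the default), `jax2tf.convert` now lowers `lax.dot_general` through the
`XlaDotV2` op, so `preferred_element_type` is honored and the converted
function should match JAX numerically. The function `f` below is a
hypothetical example, not something defined in this patch.

    import jax.numpy as jnp
    import numpy as np
    import tensorflow as tf
    from jax import lax
    from jax.experimental import jax2tf

    def f(lhs, rhs):
      # Contract the last dim of lhs with the first dim of rhs, accumulating
      # the int8 operands into an int32 result via preferred_element_type.
      return lax.dot_general(lhs, rhs,
                             dimension_numbers=(((1,), (0,)), ((), ())),
                             preferred_element_type=jnp.int32)

    lhs = np.arange(6, dtype=np.int8).reshape((2, 3))
    rhs = np.arange(12, dtype=np.int8).reshape((3, 4))

    # Convert with the XLA path enabled (the default) and compile, as the
    # jax2tf tests do for "compiled" mode.
    f_tf = tf.function(jax2tf.convert(f, enable_xla=True),
                       autograph=False, jit_compile=True)
    np.testing.assert_array_equal(f(lhs, rhs), f_tf(lhs, rhs).numpy())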