Mirror of https://github.com/ROCm/jax.git (synced 2025-04-15 19:36:06 +00:00)
Copybara import of the project:
--
1ecf4f02891cad70cc8f094b49cf2458105ca366 by George Necula <gcnecula@gmail.com>:

[jax2tf] Change the conversion of dot_general to use the XLA op.

Instead of converting the dot_general to a sea of TF ops, when we enable_xla
we just use the XLA op. This has the advantage that it also supports the
preferred_element_type.

Fixed a bug with passing the precision parameter to TF.
Also improved tests to print the HLO in case of numerical errors.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/6717 from gnecula:tf_dot 1ecf4f02891cad70cc8f094b49cf2458105ca366
PiperOrigin-RevId: 373326655
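A minimal usage sketch of the path this commit changes (not part of the commit; the function and shapes are illustrative, matching the harness defaults used below). With `enable_xla=True`, the converted function lowers `lax.dot_general` to the `XlaDot` TF op and keeps the requested precision instead of rewriting the dot into a sea of TF ops:

```python
import numpy as np
import tensorflow as tf  # type: ignore[import]
from jax import lax
from jax.experimental import jax2tf

def jax_matmul(lhs, rhs):
  # Contract the last dimension of lhs with the first dimension of rhs.
  return lax.dot_general(lhs, rhs,
                         dimension_numbers=(((1,), (0,)), ((), ())),
                         precision=lax.Precision.HIGHEST)

tf_matmul = jax2tf.convert(jax_matmul, enable_xla=True)
lhs = np.ones((3, 4), np.float32)
rhs = np.ones((4, 2), np.float32)
print(tf_matmul(tf.constant(lhs), tf.constant(rhs)).shape)  # (3, 2)
```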
parent aa74314c1a
commit 235eb8c2b4
@@ -11,9 +11,15 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
 ## jax 0.2.14 (unreleased)
 * [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.13...master).
+
+* Bug fixes:
+  * The {func}`jax2tf.convert` now converts `lax.dot_general` using the
+    `XlaDot` TensorFlow op, for better fidelity w.r.t. JAX numerical precision
+    ({jax-issue}`#6717`).
+
 ## jaxlib 0.1.67 (unreleased)

 ## jaxlib 0.1.66 (May 11 2021)

 * New features:
   * CUDA 11.1 wheels are now supported on all CUDA 11 versions 11.1 or higher.

@@ -6652,28 +6652,35 @@ def remaining(original, *removed_lists):
   return [i for i in original if i not in removed]


-def _canonicalize_precision(precision):
+def _canonicalize_precision(precision: PrecisionLike) -> Optional[Tuple[PrecisionType, PrecisionType]]:
+  """Turns an API precision specification, into a pair of enumeration values.
+
+  The API can take the precision as a string, or int, and either as a single
+  value to apply to both operands, or as a sequence of two values.
+  """
   if precision is None:
     if config.jax_default_matmul_precision is None:
       return None
     try:
-      return _precision_strings[config.jax_default_matmul_precision]
+      precision = _precision_strings[config.jax_default_matmul_precision]
+      return (precision, precision)
     except KeyError:
       raise ValueError(
           "jax_default_matmul_precision flag must be set to None or a value in "
           f"{_precision_strings}, but got {config.jax_default_matmul_precision}"
       ) from None
   elif isinstance(precision, str) and precision in _precision_strings:
-    return _precision_strings.get(precision)
+    precision = _precision_strings.get(precision)
+    return (precision, precision)
   elif isinstance(precision, Precision):
-    return precision
+    return (precision, precision)
   elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and
         all(isinstance(p, Precision) for p in precision)):
-    return precision
+    return precision  # type: ignore[return-value]
   elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and
         all(isinstance(s, str) for s in precision)):
     s1, s2 = precision
-    return (_canonicalize_precision(s1), _canonicalize_precision(s2))
+    return (_canonicalize_precision(s1)[0], _canonicalize_precision(s2)[0])  # type: ignore
   else:
     raise ValueError(
         f"Precision argument must be None, a string in {_precision_strings}, "
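A hedged usage sketch (not part of the diff) of the precision spellings that the new `_canonicalize_precision` accepts; each is canonicalized to a per-operand pair. The string-to-enum mapping follows the tests further down (for example, "float32" maps to `Precision.HIGHEST`, "tensorfloat32" to `Precision.HIGH`, "bfloat16" to `Precision.DEFAULT`):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((4, 4), jnp.float32)

lax.dot(x, x)                                      # None: falls back to jax_default_matmul_precision
lax.dot(x, x, precision="float32")                 # string, applied to both operands -> (HIGHEST, HIGHEST)
lax.dot(x, x, precision=lax.Precision.HIGHEST)     # single enum -> (HIGHEST, HIGHEST)
lax.dot(x, x, precision=(lax.Precision.DEFAULT,
                         lax.Precision.HIGHEST))   # explicit per-operand pair
lax.dot(x, x, precision=("bfloat16", "float32"))   # pair of strings, canonicalized element-wise
```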
@@ -456,6 +456,7 @@ We use the following TFXLA ops:
 * `XlaPad` (wraps XLA Pad operator). We use this instead of `tf.pad` in order to
   support `lax.pad` interior padding (dilation) or negative edge padding.
 * `XlaConv` (wraps XLA ConvGeneralDilated operator).
+* `XlaDot` and `XlaDotV2` (wraps XLA DotGeneral operator).
 * `XlaGather` (wraps XLA Gather operator). We could use `tf.gather` in some
   cases but not always. Also, `tf.gather` has a different semantics than `lax.gather`
   for index out of bounds.
@@ -15,7 +15,7 @@

 See README.md for instructions.
 """
-import grpc
+import grpc  # type: ignore[import]
 import json
 import logging
 import requests
@@ -26,9 +26,9 @@ from absl import flags
 from jax.experimental.jax2tf.examples import mnist_lib  # type: ignore

 import numpy as np
-import tensorflow as tf  # type: ignore
-import tensorflow_datasets as tfds  # type: ignore
-from tensorflow_serving.apis import predict_pb2
+import tensorflow as tf  # type: ignore[import]
+import tensorflow_datasets as tfds  # type: ignore[import]
+from tensorflow_serving.apis import predict_pb2  # type: ignore[import]
 from tensorflow_serving.apis import prediction_service_pb2_grpc

@@ -21,8 +21,8 @@ from jax.experimental.jax2tf.examples import mnist_lib

 import numpy as np

-import tensorflow as tf
-import tensorflow_datasets as tfds
+import tensorflow as tf  # type: ignore[import]
+import tensorflow_datasets as tfds  # type: ignore[import]

 flags.DEFINE_string('tflite_file_path',
                     '/usr/local/google/home/qiuminxu/jax2tf/mnist.tflite',
@@ -1,11 +1,11 @@
 # Primitives with limited JAX support

-*Last generated on: 2021-03-29* (YYYY-MM-DD)
+*Last generated on: 2021-05-12* (YYYY-MM-DD)

 ## Supported data types for primitives

-We use a set of 2313 test harnesses to test
-the implementation of 122 numeric JAX primitives.
+We use a set of 2418 test harnesses to test
+the implementation of 121 numeric JAX primitives.
 We consider a JAX primitive supported for a particular data
 type if it is supported on at least one device type.
 The following table shows the dtypes at which primitives
@@ -76,7 +76,7 @@ be updated.
 | device_put | 16 | all | |
 | digamma | 4 | floating | bool, complex, integer |
 | div | 20 | inexact, integer | bool |
-| dot_general | 125 | all | |
+| dot_general | 245 | all | |
 | dynamic_slice | 32 | all | |
 | dynamic_update_slice | 21 | all | |
 | eig | 72 | inexact | bool, integer |
@@ -156,7 +156,6 @@ be updated.
 | svd | 120 | inexact | bool, integer |
 | tan | 6 | inexact | bool, integer |
 | tanh | 6 | inexact | bool, integer |
-| tie_in | 15 | all | |
 | top_k | 15 | bool, floating, integer | complex |
 | transpose | 17 | all | |
 | triangular_solve | 26 | inexact | bool, integer |
@@ -188,6 +187,9 @@ and search for "limitation".
 |cummax|unimplemented|complex64|tpu|
 |cummin|unimplemented|complex64|tpu|
 |cumprod|unimplemented|complex64|tpu|
+|dot_general|preferred_element_type=c128 not implemented|complex64|tpu|
+|dot_general|preferred_element_type=f64 crashes (b/187884887)|bfloat16, float16, float32|tpu|
+|dot_general|preferred_element_type=i64 not implemented|int16, int32, int8|tpu|
 |eig|only supported on CPU in JAX|all|tpu, gpu|
 |eig|unimplemented|bfloat16, float16|cpu|
 |eigh|complex eigh not supported |complex|tpu|
@@ -202,7 +204,6 @@ and search for "limitation".
 |select_and_scatter_add|works only for 2 or more inactive dimensions|all|tpu|
 |svd|complex not implemented. Works in JAX for CPU and GPU with custom kernels|complex|tpu|
 |svd|unimplemented|bfloat16, float16|cpu, gpu|
-|tie_in|requires omnistaging to be disabled|all|cpu, gpu, tpu|
 |triangular_solve|unimplemented|float16|gpu|

 ## Table generation
@@ -68,6 +68,8 @@ def _sanitize_scope_name(name):
 # or Python scalar or numpy.ndarray. (A tf.EagerTensor is a tf.Tensor.)
 TfVal = Any
 DType = Any
+PrecisionType = int  # Enum xla_data.PrecisionConfig.Precision
+
 def _is_tfval(v: TfVal) -> bool:
   if isinstance(v, (tf.Tensor, tf.Variable)):
     return True
@@ -123,7 +125,6 @@ def _xla_path_disabled_error(primitive_name: str) -> Exception:
 @functools.partial(api_util.api_hook, tag="jax2tf_convert")
 def convert(fun: Callable, *,
             polymorphic_shapes: Optional[Sequence[Any]]=None,
-            in_shapes=None, # DEPRECATED
             with_gradient=True, enable_xla=True) -> Callable:
   """Transforms `fun` to be executed by TensorFlow.

@@ -222,13 +223,13 @@ def convert(fun: Callable, *,
     raise TypeError(msg)
   polymorphic_shapes_ = tuple(polymorphic_shapes)

-  # Expand the in_shapes to match the argument pytree
+  # Expand the polymorphic_shapes to match the argument pytree
   polymorphic_shapes_flat = tuple(api_util.flatten_axes("jax2tf.convert polymorphic_shapes",
                                                         in_tree.children()[0],
                                                         polymorphic_shapes_))

   # Construct the abstract values for the flat arguments, possibly based on
-  # the input shapes and the in_shapes if given. May create new shape
+  # the input shapes and the polymorphic_shapes if given. May create new shape
   # variables.
   args_avals_flat, shapeenv = _args_to_avals_and_env(args_flat,
                                                      polymorphic_shapes_flat)
@@ -557,11 +558,16 @@ class TensorFlowTracer(core.Tracer):
     elif isinstance(val, (tf.Tensor, tf.Variable)):
       val_shape, val_dtype = _tfval_shape_dtype(val)
       aval_dtype = np.dtype(self._aval.dtype)  # type: ignore[attr-defined]
-      if val_dtype != aval_dtype and (val_dtype == tf.int32 and aval_dtype == jnp.int64 or
-                                      val_dtype == tf.int64 and aval_dtype == jnp.int32 or
-                                      val_dtype == tf.float32 and aval_dtype == jnp.float64 or
-                                      val_dtype == tf.float64 and aval_dtype == jnp.float32):
-        # We expect that x64 values are turned into x32
+      if (val_dtype != aval_dtype and
+          not config.x64_enabled and
+          (val_dtype == tf.int32 and aval_dtype == jnp.int64 or
+           val_dtype == tf.int64 and aval_dtype == jnp.int32 or
+           val_dtype == tf.float32 and aval_dtype == jnp.float64 or
+           val_dtype == tf.float64 and aval_dtype == jnp.float32 or
+           val_dtype == tf.complex128 and aval_dtype == jnp.complex64)):
+        # If JAX does not have x64 bit mode enabled, it will force the 64-bit
+        # values to use 32-bit precision. In order to make the TF conversion
+        # follow JAX's rules, we cast the TF values down to 32-bit mode.
         val = tf.cast(val, dtype=aval_dtype)
         val_dtype = aval_dtype

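A hedged illustration of the downcast rule added above (assuming the default `jax_enable_x64=False`; not part of the commit): a 64-bit TF input to a converted function should come back as its 32-bit counterpart, matching what JAX itself would compute.

```python
import numpy as np
import tensorflow as tf  # type: ignore[import]
from jax.experimental import jax2tf

f_tf = jax2tf.convert(lambda x: x + 1.)
x = tf.constant(np.ones((2,), np.float64))
print(f_tf(x).dtype)  # expected: tf.float32 when jax_enable_x64 is off
```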
@@ -569,7 +575,8 @@ class TensorFlowTracer(core.Tracer):
       assert aval_dtype == val_dtype, f"expected {aval_dtype} == {val_dtype}"
       for aval_dim, val_dim in util.safe_zip(self._aval.shape, val_shape):  # type: ignore[attr-defined]
         if val_dim is None:
-          assert isinstance(aval_dim, shape_poly.DimVar), f"expected {self._aval.shape} == {val_shape}"  # type: ignore[attr-defined]
+          assert isinstance(aval_dim,
+                            shape_poly.DimVar), f"expected {self._aval.shape} == {val_shape}"  # type: ignore[attr-defined]
         elif not isinstance(aval_dim, shape_poly.DimVar):
           assert aval_dim == val_dim, f"expected {self._aval.shape} == {val_shape}"  # type: ignore[attr-defined]
         else:
@@ -1082,13 +1089,14 @@ def _conv_general_dimension_numbers_proto(dimension_numbers):
   return proto


-def _conv_general_precision_config_proto(precision):
+def _precision_config_proto(precision: Optional[Tuple[PrecisionType, PrecisionType]]):
   """Convert an integer to an XLA.PrecisionConfig."""
   if precision is None:
     return None

   proto = xla_data_pb2.PrecisionConfig()
-  proto.operand_precision.append(int(precision))
+  proto.operand_precision.append(int(precision[0]))
+  proto.operand_precision.append(int(precision[1]))
   return proto

 # _try_tf_conv returns a Tensor when it succeeds, or a string describing why
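A small sketch of the proto the renamed helper now fills in for a canonicalized precision pair (not part of the diff). It assumes the `xla_data_pb2` import path that jax2tf already uses, and the standard `xla_data.PrecisionConfig` enum values (DEFAULT=0, HIGH=1, HIGHEST=2):

```python
from jax import lax
from tensorflow.compiler.xla import xla_data_pb2  # type: ignore[import]

precision = (lax.Precision.DEFAULT, lax.Precision.HIGHEST)
proto = xla_data_pb2.PrecisionConfig()
proto.operand_precision.append(int(precision[0]))  # lhs operand precision
proto.operand_precision.append(int(precision[1]))  # rhs operand precision
print(list(proto.operand_precision))               # [0, 2]
```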
@@ -1196,9 +1204,11 @@ def _try_tf_conv(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
     return error
   return convert_dilation_and_compute_result(tf_padding, tf_dim_nums)

-def _conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation,
+def _conv_general_dilated(lhs, rhs, *,
+                          window_strides, padding, lhs_dilation,
                           rhs_dilation, dimension_numbers, feature_group_count,
-                          batch_group_count, lhs_shape, rhs_shape, precision,
+                          batch_group_count, lhs_shape, rhs_shape,
+                          precision: Optional[Tuple[PrecisionType, PrecisionType]],
                           preferred_element_type, _in_avals, _out_aval):
   """Implementation of lax.conv_general_dilated_p using XlaConv."""
   if not _enable_xla:
@@ -1212,7 +1222,7 @@ def _conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation,
     raise _xla_path_disabled_error("conv_general_dilated")

   dnums_proto = _conv_general_dimension_numbers_proto(dimension_numbers)
-  precision_config_proto = _conv_general_precision_config_proto(precision)
+  precision_config_proto = _precision_config_proto(precision)
   assert batch_group_count == 1  # TODO(phawkins): implement batch_group_count
   out = tfxla.conv(
       lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
@@ -1226,24 +1236,38 @@ def _conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation,
 tf_impl_with_avals[lax.conv_general_dilated_p] = _conv_general_dilated


-def _dot_general(lhs, rhs, dimension_numbers, precision, preferred_element_type):
+def _dot_general(lhs, rhs, *,
+                 dimension_numbers,
+                 precision: Optional[Tuple[PrecisionType, PrecisionType]],
+                 preferred_element_type: Optional[DType],
+                 _in_avals: Sequence[core.AbstractValue],
+                 _out_aval: core.AbstractValue):
   """Implementation of lax.dot_general_p in terms of tf.linalg.einsum."""
-  del precision
-  del preferred_element_type
   (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
   lhs_ndim, rhs_ndim = len(lhs.shape), len(rhs.shape)
+  if _enable_xla:
+    dnums_proto = xla_data_pb2.DotDimensionNumbers()
+    dnums_proto.lhs_contracting_dimensions.extend(lhs_contracting)
+    dnums_proto.rhs_contracting_dimensions.extend(rhs_contracting)
+    dnums_proto.lhs_batch_dimensions.extend(lhs_batch)
+    dnums_proto.rhs_batch_dimensions.extend(rhs_batch)
+    precision_config_proto = _precision_config_proto(precision)
+    res = tfxla.dot_general(lhs, rhs, dnums_proto, precision_config_proto,
+                            preferred_element_type=preferred_element_type)
+    # TODO: in presence of None dimensions, XlaDot shape inference returns
+    # unknown shape.
+    res.set_shape(_aval_to_tf_shape(_out_aval))
+    return res
+
   # This condition ensures that:
-  # 1) the considered dtype is not tf.bfloat16/tf.int32, which are supported by
-  #    tf.linalg.einsum but not by tf.linalg.matmul;
-  # 2) the batch dimensions are ordered in the same way in lhs and rhs (this is
+  # 1) the batch dimensions are ordered in the same way in lhs and rhs (this is
   #    not strictly necessary, but we would have to reshape the array if that
   #    were not the case;
-  # 3) lhs and rhs have the same number of dimensions +/- 1
-  # 4) the number of non-batch dimensions in both tensors is either 1 or 2
-  # 5) the contracting dimensions are consistent with those of a classic
+  # 2) lhs and rhs have the same number of dimensions +/- 1
+  # 3) the number of non-batch dimensions in both tensors is either 1 or 2
+  # 4) the contracting dimensions are consistent with those of a classic
   #    matrix/matrix, vector/matrix or matrix/vector multiplication.
-  if (not lhs.dtype in [tf.bfloat16, tf.int32]
-      and lhs_batch == rhs_batch == tuple(range(len(lhs_batch)))
+  if (lhs_batch == rhs_batch == tuple(range(len(lhs_batch)))
       and lhs_ndim - rhs_ndim in [-1, 0, 1]
       and 1 <= lhs_ndim - len(lhs_batch) <= 2
       and 1 <= rhs_ndim - len(rhs_batch) <= 2
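For concreteness, a hedged sketch of the `DotDimensionNumbers` proto that the new XLA path fills in for a batched matmul `(b, m, k) @ (b, k, n)` (field names come from the code above; the dimension numbers and import path are illustrative assumptions):

```python
from tensorflow.compiler.xla import xla_data_pb2  # type: ignore[import]

# lax.dot_general dimension_numbers: ((lhs_contracting, rhs_contracting),
#                                     (lhs_batch, rhs_batch))
dimension_numbers = (((2,), (1,)), ((0,), (0,)))
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers

dnums_proto = xla_data_pb2.DotDimensionNumbers()
dnums_proto.lhs_contracting_dimensions.extend(lhs_contracting)
dnums_proto.rhs_contracting_dimensions.extend(rhs_contracting)
dnums_proto.lhs_batch_dimensions.extend(lhs_batch)
dnums_proto.rhs_batch_dimensions.extend(rhs_batch)
print(dnums_proto)
```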
@@ -1290,16 +1314,16 @@ def _dot_general(lhs, rhs, dimension_numbers, precision, preferred_element_type)
     shared_id = next(new_id)
     lhs_axis_ids[lhs_axis] = shared_id
     rhs_axis_ids[rhs_axis] = shared_id
-    lhs_out_axis_ids[lhs_axis] = None
-    rhs_out_axis_ids[rhs_axis] = None
+    lhs_out_axis_ids[lhs_axis] = None  # type: ignore[call-overload]
+    rhs_out_axis_ids[rhs_axis] = None  # type: ignore[call-overload]

   batch_ids = []
   for lhs_axis, rhs_axis in zip(lhs_batch, rhs_batch):
     shared_id = next(new_id)
     lhs_axis_ids[lhs_axis] = shared_id
     rhs_axis_ids[rhs_axis] = shared_id
-    lhs_out_axis_ids[lhs_axis] = None
-    rhs_out_axis_ids[rhs_axis] = None
+    lhs_out_axis_ids[lhs_axis] = None  # type: ignore[call-overload]
+    rhs_out_axis_ids[rhs_axis] = None  # type: ignore[call-overload]
     batch_ids.append(shared_id)

   not_none = lambda x: x is not None
@@ -1310,7 +1334,7 @@ def _dot_general(lhs, rhs, dimension_numbers, precision, preferred_element_type)
                             "".join(rhs_axis_ids),
                             "".join(out_axis_ids))
   return tf.linalg.einsum(spec, lhs, rhs)
-tf_impl[lax.dot_general_p] = _dot_general
+tf_impl_with_avals[lax.dot_general_p] = _dot_general


 def _broadcast(operand, *, sizes):
@@ -460,17 +460,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
             modes="compiled",
             # Works for 2D matrices.
             enabled=(len(harness.params["lhs_shape"]) > 2)),
-        custom_numeric(dtypes=dtypes.bfloat16, tol=0.3),
-        custom_numeric(
-            dtypes=[np.complex64, np.float32], devices=("cpu", "gpu"),
-            tol=1e-5),
-        custom_numeric(
-            dtypes=[np.complex128, np.float64], devices=("cpu", "gpu"),
-            tol=1e-12),
-        custom_numeric(dtypes=np.float32, devices="tpu", tol=0.1),
-        custom_numeric(dtypes=np.complex64, devices="tpu", tol=0.3),
-        custom_numeric(dtypes=np.float16, devices=("gpu", "tpu"), tol=0.1),
-        custom_numeric(dtypes=np.float16, devices="cpu", tol=0.01)
     ]

   @classmethod
@@ -2352,22 +2352,49 @@ def _make_dot_general_harness(name,
                               rhs_shape=(4, 2),
                               dtype=np.float32,
                               precision=None,
-                              dimension_numbers=(((1,), (0,)), ((), ()))):
+                              dimension_numbers=(((1,), (0,)), ((), ())),
+                              preferred_element_type=None):
+  suffix = ""
+  if precision is not None:
+    suffix += f"_precision={precision}"
+  if preferred_element_type is not None:
+    suffix += f"_preferred={jtu.dtype_str(preferred_element_type)}"
   define(
       lax.dot_general_p,
-      f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}_dimensionnumbers={dimension_numbers}_precision={precision}"
+      f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}_dimensionnumbers={dimension_numbers}{suffix}"
       .replace(" ", ""),
       lax.dot_general, [
           RandArg(lhs_shape, dtype),
           RandArg(rhs_shape, dtype),
           StaticArg(dimension_numbers),
-          StaticArg(precision)
+          StaticArg(precision),
+          StaticArg(preferred_element_type)
       ],
       dtype=dtype,
       lhs_shape=lhs_shape,
       rhs_shape=rhs_shape,
       dimension_numbers=dimension_numbers,
-      precision=precision)
+      precision=precision,
+      preferred_element_type=preferred_element_type,
+      jax_unimplemented=[
+          Limitation(
+              "preferred_element_type=c128 not implemented",
+              devices="tpu",
+              dtypes=np.complex64,
+              enabled=(preferred_element_type in [np.complex128])),
+          Limitation(
+              "preferred_element_type=f64 crashes (b/187884887)",
+              devices="tpu",
+              dtypes=(np.float16, jnp.bfloat16, np.float32),
+              enabled=(preferred_element_type in [np.float64]),
+              skip_run=True),
+          Limitation(
+              "preferred_element_type=i64 not implemented",
+              devices="tpu",
+              dtypes=(np.int8, np.int16, np.int32),
+              enabled=(preferred_element_type in [np.int64])),
+      ],
+  )


 # There are two execution paths in the conversion of dot_general. The main path
@@ -2422,6 +2449,28 @@ for lhs_shape, rhs_shape, dimension_numbers in [
       rhs_shape=rhs_shape,
       dimension_numbers=dimension_numbers)

+# Validate preferred element type
+# From lax_test.py
+preferred_type_combinations = [
+    (np.float16, np.float16), (np.float16, np.float32), (np.float16, np.float64),
+    (dtypes.bfloat16, np.float32),
+    (dtypes.bfloat16, np.float64), (np.float32, np.float32), (np.float32, np.float64),
+    (np.int8, np.int16), (np.int8, np.int32),
+    (np.int8, np.int64), (np.int16, np.int32), (np.int16, np.int64),
+    (np.int32, np.int32), (np.int32, np.int64),
+    (np.complex64, np.complex128)]
+
+for lhs_shape in [(3,), (4, 3)]:
+  for rhs_shape in [(3, ), (3, 6)]:
+    for dtype, preferred_element_type in preferred_type_combinations:
+      _make_dot_general_harness(
+          "preferred",
+          dtype=dtype,
+          lhs_shape=lhs_shape,
+          rhs_shape=rhs_shape,
+          dimension_numbers=(((len(lhs_shape) - 1,), (0,)), ((), ())),
+          preferred_element_type=preferred_element_type)
+

 def _make_concatenate_harness(name,
                               *,
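A small grounded example of what these harnesses exercise (not part of the diff): `preferred_element_type` requests a wider accumulation/result dtype from `dot_general`, here int8 inputs accumulating into int32, one of the pairs listed above.

```python
import numpy as np
from jax import lax

lhs = np.arange(12, dtype=np.int8).reshape(4, 3)
rhs = np.arange(6, dtype=np.int8).reshape(3, 2)
out = lax.dot_general(lhs, rhs,
                      dimension_numbers=(((1,), (0,)), ((), ())),
                      preferred_element_type=np.int32)
print(out.dtype)  # int32
```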
@@ -2521,8 +2570,6 @@ for batch_group_count, feature_group_count in [
       feature_group_count=feature_group_count,
       batch_group_count=batch_group_count)

-### XXX
-
 # Validate variations of window_strides
 for window_strides in [(2, 3)]:
   _make_conv_harness("window_strides", window_strides=window_strides)
@@ -99,7 +99,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
   # If you want to run this test for only one harness, add parameter
   # `one_containing="foo"` to parameterized below.
   @primitive_harness.parameterized(
-      primitive_harness.all_harnesses, include_jax_unimpl=False
+      primitive_harness.all_harnesses, include_jax_unimpl=False,
   )
   @jtu.ignore_warning(
       category=UserWarning, message="Using reduced precision for gradient.*")
@@ -332,11 +332,11 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
         tf_value_and_grad, autograph=False).get_concrete_function(
             tf.TensorSpec([None, None, 8, 9]))

-    # The shape of the value
-    self.assertEqual((None, None, 8, 8), tuple(tf_grad.output_shapes[0]))
-    # The shape of the gradient should match the input
-    # TODO: there seems to be a bug here, the output should be (None, None, 8, 9)
-    # self.assertEqual((None, None, 8, None), tuple(tf_grad.output_shapes[1]))
+    # The shape of the value. This should be (None, None, 8, 8) but the
+    # shape inference for XlaDot is broken, and returns too many unknown
+    # dimensions.
+    self.assertEqual((None, None, None, None), tuple(tf_grad.output_shapes[0]))
+    self.assertEqual((None, None, None, None), tuple(tf_grad.output_shapes[1]))

   def test_gradients_pytree(self):
     """Shape polymorphism with gradients and pytrees for inputs and outputs."""
@@ -742,12 +742,18 @@ _POLY_SHAPE_TEST_HARNESSES = [
                   [RandArg((3,), _f32)],
                   poly_axes=[0]),

-    _make_harness("jnp_matmul", "",
+    _make_harness("jnp_matmul", "0",
                   jnp.matmul,
                   [RandArg((7, 8, 4), _f32), RandArg((7, 4, 5), _f32)],
                   poly_axes=[0, 0],
                   tol=1e-5),

+    _make_harness("jnp_matmul", "1",
+                  jnp.matmul,
+                  [RandArg((7, 8, 4), _f32), RandArg((4, 5), _f32)],
+                  poly_axes=[0, None],
+                  tol=1e-5),
+
     _make_harness("jnp_where", "",
                   jnp.where,
                   [RandArg((2,), np.bool_), RandArg((), _f32), RandArg((2,), _f32)],
@@ -935,6 +941,10 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
   """Tests for primitives that take shape values as parameters."""

+  # This test runs for all _POLY_SHAPE_PRIMITIVE_HARNESSES.
+  # For each primitive "xxx" the
+  # test will be called "test_prim_xxx_...".
+  # If you want to run this test for only one harness that includes "foo"
   # in the name, add parameter `one_containing="foo"` to parameterized below.
   @primitive_harness.parameterized(_POLY_SHAPE_TEST_HARNESSES)
   def test_prim(self, harness: Harness):
     args = harness.dyn_args_maker(self.rng())
@@ -36,7 +36,7 @@ DType = Any
 def _make_tf_args(args):
   def _convert_to_tensor(v):
     if hasattr(v, "dtype"):
-      tf.convert_to_tensor(v)
+      return tf.convert_to_tensor(v)
     return v

   return tf.nest.map_structure(_convert_to_tensor, args)
@@ -51,7 +51,6 @@ def _make_tf_input_signature(*tf_args) -> List[tf.TensorSpec]:

   return tf.nest.map_structure(_make_one_arg_signature, list(tf_args))

-
 def _run_tf_function(func_tf: Callable, *tf_args, mode: str):
   if mode == "eager":
     return func_tf(*tf_args)  # EAGER
@@ -188,13 +187,47 @@ class JaxToTfTestCase(jtu.JaxTestCase):
       custom_assert_lim = [l for l in jax2tf_limits if l.custom_assert]
       assert len(custom_assert_lim) <= 1, f"Expecting at most one applicable limitation with custom_assert, found {custom_assert_lim}"

-      if custom_assert_lim:
-        logging.info(log_message(f"Running custom_assert with tol={max_tol} due to {custom_assert_lim[0]}"))
-        custom_assert_lim[0].custom_assert(self, result_jax, result_tf, args=args, tol=max_tol)
-      else:
-        logging.info(log_message(f"Running default assert with tol={max_tol}"))
-        # In compiled mode we expect the same result as JAX by default
-        self.assertAllClose(result_jax, result_tf, atol=max_tol, rtol=max_tol)
+      try:
+        if custom_assert_lim:
+          logging.info(log_message(f"Running custom_assert with tol={max_tol} due to {custom_assert_lim[0]}"))
+          custom_assert_lim[0].custom_assert(self, result_jax, result_tf, args=args, tol=max_tol)
+        else:
+          logging.info(log_message(f"Running default assert with tol={max_tol}"))
+          self.assertAllClose(result_jax, result_tf, atol=max_tol, rtol=max_tol)
+      except AssertionError as e:
+        # Log the optimized HLO for compiled mode, it should match.
+        if mode != "compiled":
+          logging.info(f"[{self._testMethodName}] Not printing HLO because the "
+                       f"mode was {mode}")
+          raise
+
+        logging.info(f"[{self._testMethodName}] Logging HLO "
+                     f"for comparison error {e}")
+
+        jax_comp = jax.xla_computation(func_jax)(*args)
+        jax_hlo = jax_comp.as_hlo_text()
+        logging.info(f"[{self._testMethodName}] "
+                     f"JAX NON-OPT HLO\n{jax_hlo}")
+
+        tf_func_compiled = tf.function(
+            func_tf,
+            autograph=False,
+            jit_compile=True,
+            input_signature=_make_tf_input_signature(*tf_args))
+        tf_hlo = tf_func_compiled.experimental_get_compiler_ir(*tf_args)(
+            stage="hlo")
+        logging.info(f"[{self._testMethodName}] TF NON-OPT HLO\n{tf_hlo}")
+
+        backend = jax.lib.xla_bridge.get_backend()
+        modules = backend.compile(jax_comp).hlo_modules()
+        jax_opt_hlo = modules[0].to_string()
+        logging.info(f"[{self._testMethodName}] "
+                     f"JAX OPT HLO\n{jax_opt_hlo}")
+
+        tf_opt_hlo = tf_func_compiled.experimental_get_compiler_ir(*tf_args)(
+            stage="optimized_hlo")
+        logging.info(f"[{self._testMethodName}] TF OPT HLO\n{tf_opt_hlo}")
+        raise

     # end "for mode"

@@ -784,7 +784,11 @@ def assert_dot_precision(expected_precision, fun, *args):
                 if eqn.primitive == lax.dot_general_p]
   for precision in precisions:
     msg = "Unexpected precision: {} != {}".format(expected_precision, precision)
-    assert precision == expected_precision, msg
+    if isinstance(precision, tuple):
+      assert precision[0] == expected_precision, msg
+      assert precision[1] == expected_precision, msg
+    else:
+      assert precision == expected_precision, msg


 _CACHED_INDICES: Dict[int, Sequence[int]] = {}
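A hedged illustration (not part of the diff) of why the assertion now unpacks a tuple: after canonicalization, the `precision` parameter recorded on a `dot_general` equation is a per-operand pair.

```python
import jax
import jax.numpy as jnp
from jax import lax

def f(x):
  return jnp.dot(x, x, precision=lax.Precision.HIGHEST)

jaxpr = jax.make_jaxpr(f)(jnp.ones((2, 2), jnp.float32))
eqn = next(e for e in jaxpr.jaxpr.eqns if e.primitive is lax.dot_general_p)
print(eqn.params["precision"])  # (Precision.HIGHEST, Precision.HIGHEST)
```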
@@ -2591,23 +2591,23 @@ class APITest(jtu.JaxTestCase):
     with jax.default_matmul_precision("bfloat16"):
       x @ x  # doesn't crash
       jaxpr = jax.make_jaxpr(op.matmul)(x, x)
-    self.assertIn('precision=DEFAULT', str(jaxpr))
+    self.assertIn('Precision.DEFAULT', str(jaxpr))

     with jax.default_matmul_precision("tensorfloat32"):
       jnp.dot(x, x)  # doesn't crash
       jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
-    self.assertIn('precision=HIGH\n', str(jaxpr))
+    self.assertIn('Precision.HIGH', str(jaxpr))

     with jax.default_matmul_precision("float32"):
       jnp.dot(x, x)  # doesn't crash
       jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
-    self.assertIn('precision=HIGHEST', str(jaxpr))
+    self.assertIn('Precision.HIGHEST', str(jaxpr))

     dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)
     with jax.default_matmul_precision("tensorfloat32"):
       dot(x, x)  # doesn't crash
       jaxpr = jax.make_jaxpr(dot)(x, x)
-    self.assertIn('precision=HIGHEST', str(jaxpr))
+    self.assertIn('Precision.HIGHEST', str(jaxpr))

   def test_dot_precision_flag(self):
     x = jnp.zeros((2, 2))
@@ -2619,7 +2619,7 @@ class APITest(jtu.JaxTestCase):
       jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
     finally:
       config.FLAGS.jax_default_matmul_precision = prev_val
-    self.assertIn('precision=HIGH', str(jaxpr))
+    self.assertIn('Precision.HIGH', str(jaxpr))
     self.assertEqual(prev_val, config._read("jax_default_matmul_precision"))

     prev_val = config._read("jax_default_matmul_precision")
@@ -2629,7 +2629,7 @@ class APITest(jtu.JaxTestCase):
       jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
     finally:
       config.update('jax_default_matmul_precision', prev_val)
-    self.assertIn('precision=HIGH', str(jaxpr))
+    self.assertIn('Precision.HIGH', str(jaxpr))
     self.assertEqual(prev_val, config._read("jax_default_matmul_precision"))

   @unittest.skipIf(jax.lib._xla_extension_version <= 17,
@@ -398,7 +398,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
     result, pullback = api.vjp(dot, lhs, rhs)
     gresult = lax.zeros_like_array(result)
     s = str(api.make_jaxpr(pullback)(gresult))
-    assert "precision=HIGHEST" in s
+    assert "Precision.HIGHEST" in s

   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name":
@@ -429,7 +429,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
     result, pullback = api.vjp(dot_general, lhs, rhs)
     gresult = lax.zeros_like_array(result)
     s = str(api.make_jaxpr(pullback)(gresult))
-    assert "precision=HIGHEST" in s
+    assert "Precision.HIGHEST" in s

   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}".format(