Merge pull request #6895 from gnecula:tf_uint64

PiperOrigin-RevId: 377510835
This commit is contained in:
jax authors 2021-06-04 07:49:44 -07:00
commit de9f55720d
2 changed files with 3 additions and 6 deletions

View File

@ -76,8 +76,10 @@ PrecisionType = int # Enum xla_data.PrecisionConfig.Precision
def _is_tfval(v: TfVal) -> bool:
if isinstance(v, (tf.Tensor, tf.Variable)):
return True
try:
tf.convert_to_tensor(v)
tf.constant(v)
return True
except:
return False
@ -1755,8 +1757,6 @@ def _common_reduce_window(operand, init_val, reducer, window_dimensions,
reducer, autograph=False).get_concrete_function(o_spec, o_spec)
if not isinstance(init_val, tf.Tensor):
assert not config.jax_enable_checks or _is_tfval(
init_val), f"Non TfVal: {init_val}"
init_val = tf.constant(init_val, operand.dtype)
out = tfxla.reduce_window(
operand,

View File

@ -987,9 +987,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
assert "min" == harness.params["computation"].__name__
return [
missing_tf_kernel(dtypes=[np.bool_, np.complex64, np.complex128]),
missing_tf_kernel(dtypes=[np.uint64],
devices=("cpu", "gpu"),
modes=("eager",)),
]
@classmethod