mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #6895 from gnecula:tf_uint64
PiperOrigin-RevId: 377510835
This commit is contained in:
commit
de9f55720d
@ -76,8 +76,10 @@ PrecisionType = int # Enum xla_data.PrecisionConfig.Precision
|
||||
|
||||
|
||||
def _is_tfval(v: TfVal) -> bool:
|
||||
if isinstance(v, (tf.Tensor, tf.Variable)):
|
||||
return True
|
||||
try:
|
||||
tf.convert_to_tensor(v)
|
||||
tf.constant(v)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
@ -1755,8 +1757,6 @@ def _common_reduce_window(operand, init_val, reducer, window_dimensions,
|
||||
reducer, autograph=False).get_concrete_function(o_spec, o_spec)
|
||||
|
||||
if not isinstance(init_val, tf.Tensor):
|
||||
assert not config.jax_enable_checks or _is_tfval(
|
||||
init_val), f"Non TfVal: {init_val}"
|
||||
init_val = tf.constant(init_val, operand.dtype)
|
||||
out = tfxla.reduce_window(
|
||||
operand,
|
||||
|
@ -987,9 +987,6 @@ class Jax2TfLimitation(primitive_harness.Limitation):
|
||||
assert "min" == harness.params["computation"].__name__
|
||||
return [
|
||||
missing_tf_kernel(dtypes=[np.bool_, np.complex64, np.complex128]),
|
||||
missing_tf_kernel(dtypes=[np.uint64],
|
||||
devices=("cpu", "gpu"),
|
||||
modes=("eager",)),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
|
Loading…
x
Reference in New Issue
Block a user