Fix jax2tf import so it works with both the latest tensorflow release (2.10.0) and tf-nightly

This commit is contained in:
Skye Wanderman-Milne 2022-09-30 15:55:22 -07:00
parent fb8558cfdd
commit 0a69c9a27b

View File

@ -68,7 +68,11 @@ import tensorflow as tf # type: ignore[import]
from tensorflow.compiler.tf2xla.python import xla as tfxla # type: ignore[import]
from tensorflow.compiler.xla import xla_data_pb2 # type: ignore[import]
from tensorflow.core.framework import attr_value_pb2 # type: ignore[import]
from tensorflow.python.compiler.xla.experimental import xla_sharding # type: ignore[import]
try:
from tensorflow.python.compiler.xla.experimental import xla_sharding # type: ignore[import]
except ModuleNotFoundError:
# This can be removed when TF 2.10 support is no longer needed.
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding # type: ignore[import]
from tensorflow.python.framework import ops as tf_ops # type: ignore[import]
# pylint: enable=g-direct-tensorflow-import