mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Fix jax2tf import so it works with both the latest tensorflow release (2.10.0) and tf-nightly
This commit is contained in:
parent
fb8558cfdd
commit
0a69c9a27b
@ -68,7 +68,11 @@ import tensorflow as tf # type: ignore[import]
|
||||
from tensorflow.compiler.tf2xla.python import xla as tfxla # type: ignore[import]
|
||||
from tensorflow.compiler.xla import xla_data_pb2 # type: ignore[import]
|
||||
from tensorflow.core.framework import attr_value_pb2 # type: ignore[import]
|
||||
from tensorflow.python.compiler.xla.experimental import xla_sharding # type: ignore[import]
|
||||
try:
|
||||
from tensorflow.python.compiler.xla.experimental import xla_sharding # type: ignore[import]
|
||||
except ModuleNotFoundError:
|
||||
# This can be removed when TF 2.10 support is no longer needed.
|
||||
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding # type: ignore[import]
|
||||
from tensorflow.python.framework import ops as tf_ops # type: ignore[import]
|
||||
# pylint: enable=g-direct-tensorflow-import
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user