diff --git a/jax/experimental/jax2tf/__init__.py b/jax/experimental/jax2tf/__init__.py index 6a2dd2fd7..918c63a95 100644 --- a/jax/experimental/jax2tf/__init__.py +++ b/jax/experimental/jax2tf/__init__.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax.experimental.jax2tf.jax2tf import (convert, dtype_of_val, - split_to_logical_devices, PolyShape) -from jax.experimental.jax2tf.call_tf import call_tf +from jax.experimental.jax2tf.jax2tf import ( + convert as convert, + dtype_of_val as dtype_of_val, + split_to_logical_devices as split_to_logical_devices, + PolyShape as PolyShape +) +from jax.experimental.jax2tf.call_tf import call_tf as call_tf