Import jax.nn.functions by default to fix breakage.

This commit is contained in:
Peter Hawkins 2020-10-17 14:51:39 -04:00
parent 3ddd3905a4
commit b07848359b

View File

@ -15,6 +15,10 @@
"""Common functions for neural network libraries."""
# flake8: noqa: F401
# TODO(phawkins): remove this import after fixing callers
from . import functions
from . import initializers
from jax._src.nn.functions import (
celu,