mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Import jax.nn.functions by default to fix breakage.
This commit is contained in:
parent
3ddd3905a4
commit
b07848359b
@ -15,6 +15,10 @@
|
||||
"""Common functions for neural network libraries."""
|
||||
|
||||
# flake8: noqa: F401
|
||||
|
||||
# TODO(phawkins): remove this import after fixing callers
|
||||
from . import functions
|
||||
|
||||
from . import initializers
|
||||
from jax._src.nn.functions import (
|
||||
celu,
|
||||
|
Loading…
x
Reference in New Issue
Block a user