Import jax.monitoring by default.

A JAX refactoring meant this was no longer being imported by default. Restore the previous state.

PiperOrigin-RevId: 522474571
This commit is contained in:
Peter Hawkins 2023-04-06 17:03:10 -07:00 committed by jax authors
parent b1966d9fbd
commit c7b99e6ea9

View File

@ -165,6 +165,7 @@ from jax import errors as errors
from jax import image as image
from jax import lax as lax
from jax import linear_util as linear_util
from jax import monitoring as monitoring
from jax import nn as nn
from jax import numpy as numpy
from jax import ops as ops