Add version int tuple __version_info__ to JAX

This commit is contained in:
Sharad Vikram 2022-04-05 12:42:43 -07:00
parent fef367019b
commit d72a7b4054
2 changed files with 6 additions and 0 deletions

View File

@ -123,6 +123,7 @@ from jax._src.api import (
)
from jax.experimental.maps import soft_pmap as soft_pmap
from jax.version import __version__ as __version__
from jax.version import __version_info__ as __version_info__
# These submodules are separate because they are in an import cycle with
# jax and rely on the names imported above.

View File

@ -12,6 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
def _version_as_tuple(version_str):
return tuple(int(i) for i in version_str.split(".") if i.isdigit())
__version__ = "0.3.5"
__version_info__ = _version_as_tuple(__version__)
_minimum_jaxlib_version = "0.3.0"
_minimum_jaxlib_version_info = _version_as_tuple(_minimum_jaxlib_version)