Remove lax_numpy from jax.numpy namespace

This is a private module that was inadvertently exported in the past.
This commit is contained in:
Jake VanderPlas 2022-03-25 15:02:45 -07:00
parent f57e78e240
commit f4d240c036
2 changed files with 1 additions and 3 deletions

View File

@ -14,6 +14,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* Changes:
* added {func}`jax.random.loggamma` & improved behavior of {func}`jax.random.beta`
and {func}`jax.random.dirichlet` for small parameter values ({jax-issue}`#9906`).
* the private `lax_numpy` submodule is no longer exposed in the `jax.numpy` namespace ({jax-issue}`#10029`).
* Deprecations:
* {func}`jax.nn.normalize` is being deprecated. Use {func}`jax.nn.standardize` instead ({jax-issue}`#9899`).

View File

@ -436,9 +436,6 @@ from jax._src.numpy.ufuncs import (
from jax._src.numpy.vectorize import vectorize as vectorize
# TODO(phawkins): remove this import after fixing users.
from jax._src.numpy import lax_numpy
# Module initialization is encapsulated in a function to avoid accidental
# namespace pollution.
def _init():