Use a whitelist to restrict visibility in top-level jax namespace. (#2982)

* Use a whitelist to restrict visibility in top-level jax namespace.

The goal of this change is to capture the way the world is (i.e., not break users), and separately we will work on fixing users to avoid accidentally-exported APIs.
This commit is contained in:
Peter Hawkins 2020-05-07 17:24:19 -04:00 committed by GitHub
parent 9f04d9817b
commit 0ea22b7e19
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 76 additions and 10 deletions

View File

@ -700,7 +700,7 @@
"\n",
"# Now we register the XLA compilation rule with JAX\n",
"# TODO: for GPU? and TPU?\n",
"from jax import xla\n",
"from jax.interpreters import xla\n",
"xla.backend_specific_translations['cpu'][multiply_add_p] = multiply_add_xla_translation"
],
"execution_count": 0,
@ -876,7 +876,7 @@
"colab": {}
},
"source": [
"from jax import ad\n",
"from jax.interpreters import ad\n",
"\n",
"\n",
"@trace(\"multiply_add_value_and_jvp\")\n",
@ -1529,7 +1529,7 @@
"colab": {}
},
"source": [
"from jax import batching\n",
"from jax.interpreters import batching\n",
"\n",
"\n",
"@trace(\"multiply_add_batch\")\n",

View File

@ -12,11 +12,77 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1')
from jax.version import __version__
from jax.api import *
from .config import config
from .api import (
ad, # TODO(phawkins): update users to avoid this.
argnums_partial, # TODO(phawkins): update Haiku to not use this.
checkpoint,
curry, # TODO(phawkins): update users to avoid this.
custom_gradient,
custom_jvp,
custom_vjp,
custom_transforms,
defjvp,
defjvp_all,
defvjp,
defvjp_all,
device_count,
device_get,
device_put,
devices,
disable_jit,
eval_shape,
flatten_fun_nokwargs, # TODO(phawkins): update users to avoid this.
grad,
hessian,
host_count,
host_id,
host_ids,
jacobian,
jacfwd,
jacrev,
jit,
jvp,
local_device_count,
local_devices,
linearize,
make_jaxpr,
mask,
partial, # TODO(phawkins): update callers to use functools.partial.
pmap,
pxla, # TODO(phawkins): update users to avoid this.
remat,
shapecheck,
ShapedArray,
ShapeDtypeStruct,
soft_pmap,
# TODO(phawkins): hide tree* functions from jax, update callers to use
# jax.tree_util.
treedef_is_leaf,
tree_flatten,
tree_leaves,
tree_map,
tree_multimap,
tree_structure,
tree_transpose,
tree_unflatten,
value_and_grad,
vjp,
vmap,
xla, # TODO(phawkins): update users to avoid this.
xla_computation,
)
from jax import nn
from jax import random
import jax.numpy as np # side-effecting import sets up operator overloads
# TODO(phawkins): remove the `np` name.
import jax.numpy as np # side-effecting import sets up operator overloads
def _init():
import os
os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1')
_init()
del _init

View File

@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
from jax import numpy as jnp
from jax import test_util as jtu, jit, partial
from jax import test_util as jtu, jit
from jax.config import config
config.parse_flags_with_absl()