mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Use a whitelist to restrict visibility in top-level jax namespace. (#2982)
* Use a whitelist to restrict visibility in top-level jax namespace. The goal of this change is to capture the way the world is (i.e., not break users), and separately we will work on fixing users to avoid accidentally-exported APIs.
This commit is contained in:
parent
9f04d9817b
commit
0ea22b7e19
@ -700,7 +700,7 @@
|
||||
"\n",
|
||||
"# Now we register the XLA compilation rule with JAX\n",
|
||||
"# TODO: for GPU? and TPU?\n",
|
||||
"from jax import xla\n",
|
||||
"from jax.interpreters import xla\n",
|
||||
"xla.backend_specific_translations['cpu'][multiply_add_p] = multiply_add_xla_translation"
|
||||
],
|
||||
"execution_count": 0,
|
||||
@ -876,7 +876,7 @@
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"from jax import ad\n",
|
||||
"from jax.interpreters import ad\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@trace(\"multiply_add_value_and_jvp\")\n",
|
||||
@ -1529,7 +1529,7 @@
|
||||
"colab": {}
|
||||
},
|
||||
"source": [
|
||||
"from jax import batching\n",
|
||||
"from jax.interpreters import batching\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@trace(\"multiply_add_batch\")\n",
|
||||
|
@ -12,11 +12,77 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1')
|
||||
|
||||
from jax.version import __version__
|
||||
from jax.api import *
|
||||
from .config import config
|
||||
from .api import (
|
||||
ad, # TODO(phawkins): update users to avoid this.
|
||||
argnums_partial, # TODO(phawkins): update Haiku to not use this.
|
||||
checkpoint,
|
||||
curry, # TODO(phawkins): update users to avoid this.
|
||||
custom_gradient,
|
||||
custom_jvp,
|
||||
custom_vjp,
|
||||
custom_transforms,
|
||||
defjvp,
|
||||
defjvp_all,
|
||||
defvjp,
|
||||
defvjp_all,
|
||||
device_count,
|
||||
device_get,
|
||||
device_put,
|
||||
devices,
|
||||
disable_jit,
|
||||
eval_shape,
|
||||
flatten_fun_nokwargs, # TODO(phawkins): update users to avoid this.
|
||||
grad,
|
||||
hessian,
|
||||
host_count,
|
||||
host_id,
|
||||
host_ids,
|
||||
jacobian,
|
||||
jacfwd,
|
||||
jacrev,
|
||||
jit,
|
||||
jvp,
|
||||
local_device_count,
|
||||
local_devices,
|
||||
linearize,
|
||||
make_jaxpr,
|
||||
mask,
|
||||
partial, # TODO(phawkins): update callers to use functools.partial.
|
||||
pmap,
|
||||
pxla, # TODO(phawkins): update users to avoid this.
|
||||
remat,
|
||||
shapecheck,
|
||||
ShapedArray,
|
||||
ShapeDtypeStruct,
|
||||
soft_pmap,
|
||||
# TODO(phawkins): hide tree* functions from jax, update callers to use
|
||||
# jax.tree_util.
|
||||
treedef_is_leaf,
|
||||
tree_flatten,
|
||||
tree_leaves,
|
||||
tree_map,
|
||||
tree_multimap,
|
||||
tree_structure,
|
||||
tree_transpose,
|
||||
tree_unflatten,
|
||||
value_and_grad,
|
||||
vjp,
|
||||
vmap,
|
||||
xla, # TODO(phawkins): update users to avoid this.
|
||||
xla_computation,
|
||||
)
|
||||
from jax import nn
|
||||
from jax import random
|
||||
import jax.numpy as np # side-effecting import sets up operator overloads
|
||||
|
||||
# TODO(phawkins): remove the `np` name.
|
||||
import jax.numpy as np # side-effecting import sets up operator overloads
|
||||
|
||||
|
||||
def _init():
|
||||
import os
|
||||
os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1')
|
||||
|
||||
_init()
|
||||
del _init
|
||||
|
@ -12,14 +12,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu, jit, partial
|
||||
from jax import test_util as jtu, jit
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
Loading…
x
Reference in New Issue
Block a user