Remove jax._src from JAX namespace.

This is a JAX-internal name and not subject to any deprecation policy. Please avoid the use of JAX-internal functions outside JAX.

PiperOrigin-RevId: 473243243
This commit is contained in:
Peter Hawkins 2022-09-09 07:05:30 -07:00 committed by jax authors
parent bc59bd1ddc
commit 40c80d7d0a
4 changed files with 8 additions and 3 deletions

View File

@ -15,6 +15,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
{jax-issue}`#7733`) is stable and public. See [the
overview](https://jax.readthedocs.io/en/latest/aot.html) and the API docs
for {mod}`jax.stages`.
* Breaking changes
* `jax._src` is no longer imported into the from the public `jax` namespace.
This may break users that were using JAX internals.
## jax 0.3.17 (Aug 31, 2022)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.16...jax-v0.3.17).

View File

@ -146,3 +146,5 @@ from jax import tree_util as tree_util
from jax import util as util
import jax.lib # TODO(phawkins): remove this export.
del jax._src

View File

@ -103,7 +103,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
device_assignment = np.arange(num_partitions * num_replicas)
device_assignment = np.reshape(device_assignment, (-1, num_partitions))
use_spmd_partitioning = num_partitions > 1
compile_options = jax._src.lib.xla_bridge.get_compile_options(
compile_options = xla_bridge.get_compile_options(
num_replicas=num_replicas,
num_partitions=num_partitions,
device_assignment=device_assignment,

View File

@ -24,7 +24,7 @@ from jax import core
from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src.lib import gpu_solver
from jax._src.lib import gpu_solver, xla_extension_version
import numpy as np
@ -550,6 +550,6 @@ def spsolve(data, indices, indptr, b, tol=1e-6, reorder=1):
An array with the same dtype and size as b representing the solution to
the sparse linear system.
"""
if jax._src.lib.xla_extension_version < 86:
if xla_extension_version < 86:
raise ValueError('spsolve requires jaxlib version 86 or above.')
return spsolve_p.bind(data, indices, indptr, b, tol=tol, reorder=reorder)