diff --git a/CHANGELOG.md b/CHANGELOG.md index d100b4096..ae03a5df4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. {jax-issue}`#7733`) is stable and public. See [the overview](https://jax.readthedocs.io/en/latest/aot.html) and the API docs for {mod}`jax.stages`. +* Breaking changes + * `jax._src` is no longer imported into the from the public `jax` namespace. + This may break users that were using JAX internals. ## jax 0.3.17 (Aug 31, 2022) * [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.16...jax-v0.3.17). diff --git a/jax/__init__.py b/jax/__init__.py index 5271e73a8..75b54317e 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -146,3 +146,5 @@ from jax import tree_util as tree_util from jax import util as util import jax.lib # TODO(phawkins): remove this export. + +del jax._src diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py index 03e04ffc4..246da8ffd 100644 --- a/jax/experimental/jax2tf/tests/sharding_test.py +++ b/jax/experimental/jax2tf/tests/sharding_test.py @@ -103,7 +103,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase): device_assignment = np.arange(num_partitions * num_replicas) device_assignment = np.reshape(device_assignment, (-1, num_partitions)) use_spmd_partitioning = num_partitions > 1 - compile_options = jax._src.lib.xla_bridge.get_compile_options( + compile_options = xla_bridge.get_compile_options( num_replicas=num_replicas, num_partitions=num_partitions, device_assignment=device_assignment, diff --git a/jax/experimental/sparse/linalg.py b/jax/experimental/sparse/linalg.py index bbace7344..de08f56c3 100644 --- a/jax/experimental/sparse/linalg.py +++ b/jax/experimental/sparse/linalg.py @@ -24,7 +24,7 @@ from jax import core from jax.interpreters import mlir from jax.interpreters import xla -from jax._src.lib import gpu_solver +from jax._src.lib import gpu_solver, xla_extension_version import numpy as np @@ -550,6 +550,6 @@ def spsolve(data, indices, indptr, b, tol=1e-6, reorder=1): An array with the same dtype and size as b representing the solution to the sparse linear system. """ - if jax._src.lib.xla_extension_version < 86: + if xla_extension_version < 86: raise ValueError('spsolve requires jaxlib version 86 or above.') return spsolve_p.bind(data, indices, indptr, b, tol=tol, reorder=reorder)