mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Remove jax._src from JAX namespace.
This is a JAX-internal name and not subject to any deprecation policy. Please avoid the use of JAX-internal functions outside JAX. PiperOrigin-RevId: 473243243
This commit is contained in:
parent
bc59bd1ddc
commit
40c80d7d0a
@ -15,6 +15,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
|||||||
{jax-issue}`#7733`) is stable and public. See [the
|
{jax-issue}`#7733`) is stable and public. See [the
|
||||||
overview](https://jax.readthedocs.io/en/latest/aot.html) and the API docs
|
overview](https://jax.readthedocs.io/en/latest/aot.html) and the API docs
|
||||||
for {mod}`jax.stages`.
|
for {mod}`jax.stages`.
|
||||||
|
* Breaking changes
|
||||||
|
* `jax._src` is no longer imported into the from the public `jax` namespace.
|
||||||
|
This may break users that were using JAX internals.
|
||||||
|
|
||||||
## jax 0.3.17 (Aug 31, 2022)
|
## jax 0.3.17 (Aug 31, 2022)
|
||||||
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.16...jax-v0.3.17).
|
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.16...jax-v0.3.17).
|
||||||
|
@ -146,3 +146,5 @@ from jax import tree_util as tree_util
|
|||||||
from jax import util as util
|
from jax import util as util
|
||||||
|
|
||||||
import jax.lib # TODO(phawkins): remove this export.
|
import jax.lib # TODO(phawkins): remove this export.
|
||||||
|
|
||||||
|
del jax._src
|
||||||
|
@ -103,7 +103,7 @@ class ShardedJitHloTest(tf_test_util.JaxToTfTestCase):
|
|||||||
device_assignment = np.arange(num_partitions * num_replicas)
|
device_assignment = np.arange(num_partitions * num_replicas)
|
||||||
device_assignment = np.reshape(device_assignment, (-1, num_partitions))
|
device_assignment = np.reshape(device_assignment, (-1, num_partitions))
|
||||||
use_spmd_partitioning = num_partitions > 1
|
use_spmd_partitioning = num_partitions > 1
|
||||||
compile_options = jax._src.lib.xla_bridge.get_compile_options(
|
compile_options = xla_bridge.get_compile_options(
|
||||||
num_replicas=num_replicas,
|
num_replicas=num_replicas,
|
||||||
num_partitions=num_partitions,
|
num_partitions=num_partitions,
|
||||||
device_assignment=device_assignment,
|
device_assignment=device_assignment,
|
||||||
|
@ -24,7 +24,7 @@ from jax import core
|
|||||||
from jax.interpreters import mlir
|
from jax.interpreters import mlir
|
||||||
from jax.interpreters import xla
|
from jax.interpreters import xla
|
||||||
|
|
||||||
from jax._src.lib import gpu_solver
|
from jax._src.lib import gpu_solver, xla_extension_version
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -550,6 +550,6 @@ def spsolve(data, indices, indptr, b, tol=1e-6, reorder=1):
|
|||||||
An array with the same dtype and size as b representing the solution to
|
An array with the same dtype and size as b representing the solution to
|
||||||
the sparse linear system.
|
the sparse linear system.
|
||||||
"""
|
"""
|
||||||
if jax._src.lib.xla_extension_version < 86:
|
if xla_extension_version < 86:
|
||||||
raise ValueError('spsolve requires jaxlib version 86 or above.')
|
raise ValueError('spsolve requires jaxlib version 86 or above.')
|
||||||
return spsolve_p.bind(data, indices, indptr, b, tol=tol, reorder=reorder)
|
return spsolve_p.bind(data, indices, indptr, b, tol=tol, reorder=reorder)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user