mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Remove jax._src deletion.
This isn't a completely effective way to close off the JAX private namespace, since it's easy to work around via the module import mechanism. It also prevents us from fixing users who are mocking JAX internals. Some users, e.g. t5x, have test code like this: ``` from jax._src.lib import xla_bridge @mock.patch.object(xla_bridge, 'process_index') ... ``` A slightly cleaner solution that does not require importing the JAX internals and does not assume how the internals are laid out is: ``` @mock.patch(f'{jax.process_index.__module__}.process_index') ... ``` However, this solution requires the `jax._src` be present in the JAX namespace. Ideally users wouldn't mock our internals at all, but that requires significantly more work. PiperOrigin-RevId: 512295203
This commit is contained in:
parent
0292f5d0a6
commit
b61d5d5654
@ -173,11 +173,4 @@ from jax._src.array import Shard as Shard
|
||||
|
||||
import jax.lib # TODO(phawkins): remove this export.
|
||||
|
||||
if hasattr(jax, '_src'):
|
||||
del jax._src
|
||||
else:
|
||||
from warnings import warn as _warn
|
||||
_warn("The jax module appears to have been reloaded within the python process. "
|
||||
"This is not well-supported and can cause unpredictable side-effects. "
|
||||
"For information see https://github.com/google/jax/issues/13857.")
|
||||
del _warn
|
||||
# trailer
|
||||
|
@ -4156,25 +4156,6 @@ class APITest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(TypeError, "applied to foo"):
|
||||
f_vjp(1.0, 1.0)
|
||||
|
||||
@unittest.skipIf(not sys.executable, "test requires sys.executable")
|
||||
@jtu.skip_on_devices("gpu", "tpu")
|
||||
def test_jax_reload_warning(self):
|
||||
# Regression test for https://github.com/google/jax/issues/13857
|
||||
should_not_warn = "import jax"
|
||||
should_warn = (
|
||||
"import jax;"
|
||||
"import importlib;"
|
||||
"importlib.reload(jax)")
|
||||
expected = "The jax module appears to have been reloaded within the python process"
|
||||
|
||||
result = subprocess.run([sys.executable, '-c', should_not_warn],
|
||||
check=True, capture_output=True)
|
||||
assert expected not in result.stderr.decode()
|
||||
|
||||
result = subprocess.run([sys.executable, '-c', should_warn],
|
||||
check=True, capture_output=True)
|
||||
assert expected in result.stderr.decode()
|
||||
|
||||
def test_shapedtypestruct_sharding_error(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
|
Loading…
x
Reference in New Issue
Block a user