Remove jax._src deletion.

This isn't a completely effective way to close off the JAX private namespace, since it's easy to work around via the module import mechanism.

It also prevents us from fixing users who are mocking JAX internals. Some users, e.g. t5x, have test code like this:

```
from jax._src.lib import xla_bridge

@mock.patch.object(xla_bridge, 'process_index')
...
```

A slightly cleaner solution that does not require importing the JAX internals and does not assume how the internals are laid out is:

```
@mock.patch(f'{jax.process_index.__module__}.process_index')
...
```

However, this solution requires the `jax._src` be present in the JAX namespace.

Ideally users wouldn't mock our internals at all, but that requires significantly more work.

PiperOrigin-RevId: 512295203
This commit is contained in:
Peter Hawkins 2023-02-25 07:17:18 -08:00 committed by jax authors
parent 0292f5d0a6
commit b61d5d5654
2 changed files with 1 additions and 27 deletions

View File

@ -173,11 +173,4 @@ from jax._src.array import Shard as Shard
import jax.lib # TODO(phawkins): remove this export.
if hasattr(jax, '_src'):
del jax._src
else:
from warnings import warn as _warn
_warn("The jax module appears to have been reloaded within the python process. "
"This is not well-supported and can cause unpredictable side-effects. "
"For information see https://github.com/google/jax/issues/13857.")
del _warn
# trailer

View File

@ -4156,25 +4156,6 @@ class APITest(jtu.JaxTestCase):
with self.assertRaisesRegex(TypeError, "applied to foo"):
f_vjp(1.0, 1.0)
@unittest.skipIf(not sys.executable, "test requires sys.executable")
@jtu.skip_on_devices("gpu", "tpu")
def test_jax_reload_warning(self):
# Regression test for https://github.com/google/jax/issues/13857
should_not_warn = "import jax"
should_warn = (
"import jax;"
"import importlib;"
"importlib.reload(jax)")
expected = "The jax module appears to have been reloaded within the python process"
result = subprocess.run([sys.executable, '-c', should_not_warn],
check=True, capture_output=True)
assert expected not in result.stderr.decode()
result = subprocess.run([sys.executable, '-c', should_warn],
check=True, capture_output=True)
assert expected in result.stderr.decode()
def test_shapedtypestruct_sharding_error(self):
with self.assertRaisesRegex(
ValueError,