Import submodules from jax._src explicitly, instead of relying on import side-effects. It will lead to the missing x-refs in code search according to go/pywald-sawmill-analysis.

PiperOrigin-RevId: 604788105
This commit is contained in:
jax authors 2024-02-06 15:46:31 -08:00
parent ca77e5639f
commit 9b27d43e70
5 changed files with 9 additions and 5 deletions

View File

@ -807,6 +807,7 @@ pytype_strict_library(
":core",
":jax",
":mlir",
":sharding_impls",
"//jax/_src/lib",
] + if_building_jaxlib([
"//jaxlib/mlir:ir",

View File

@ -30,6 +30,7 @@ from jax._src.numpy.reductions import _moveaxis
from jax._src.numpy.util import implements, check_arraylike, _broadcast_to, _where
from jax._src.numpy.vectorize import vectorize
from jax._src.util import canonicalize_axis, set_module
from jax._src import pjit
import numpy as np
@ -53,8 +54,8 @@ def get_if_single_primitive(fun: Callable[..., Any], *args: Any) -> jax.core.Pri
eqn = jaxpr.eqns[0]
if (eqn.invars, eqn.outvars) != (jaxpr.jaxpr.invars, jaxpr.jaxpr.outvars):
return None
elif (eqn.primitive == jax._src.pjit.pjit_p and
all(jax._src.pjit.is_unspecified(sharding) for sharding in
elif (eqn.primitive == pjit.pjit_p and
all(pjit.is_unspecified(sharding) for sharding in
(*eqn.params['in_shardings'], *eqn.params['out_shardings']))):
jaxpr = jaxpr.eqns[0].params['jaxpr']
else:

View File

@ -34,6 +34,7 @@ from jax._src.lib import tpu
from jax._src.lib import xla_client
from jax._src.lib.mlir.dialects import hlo
from jax._src.interpreters import mlir
from jax._src import sharding_impls
from jax.interpreters import xla
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import stablehlo
@ -184,7 +185,6 @@ def _tpu_custom_call_lowering(
else:
result_type = mlir.aval_to_ir_type(out_avals[0])
axis_context = ctx.module_context.axis_context
sharding_impls = jax._src.sharding_impls # pylint: disable=protected-access
if isinstance(axis_context, sharding_impls.SPMDAxisContext):
if axis_context.manual_axes != frozenset(axis_context.mesh.axis_names):
raise NotImplementedError(

View File

@ -52,6 +52,7 @@ from jax._src import custom_derivatives
from jax._src import linear_util as lu
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src import debugging
from jax._src.ad_checkpoint import saved_residuals
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
@ -6310,7 +6311,7 @@ class JaxprTest(jtu.JaxTestCase):
jax.debug.print("{}", x)
return x
jaxpr = jax.make_jaxpr(f)(np.int32(0))
self.assertEqual(jaxpr.eqns[0].primitive, jax._src.debugging.debug_callback_p)
self.assertEqual(jaxpr.eqns[0].primitive, debugging.debug_callback_p)
self.assertStartsWith(str(jaxpr.eqns[0]), "debug_callback[", )

View File

@ -28,6 +28,7 @@ import jax.numpy as jnp
import jax.profiler
from jax import config
import jax._src.test_util as jtu
from jax._src import profiler
try:
import portpicker
@ -121,7 +122,7 @@ class ProfilerTest(unittest.TestCase):
jnp.ones(jax.local_device_count())
)
finally:
fdo_profile = jax._src.profiler.stop_and_get_fdo_profile()
fdo_profile = profiler.stop_and_get_fdo_profile()
if jtu.test_device_matches(["gpu"]) and jtu.is_device_cuda():
self.assertIn(b"copy", fdo_profile)