mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Import submodules from jax._src explicitly, instead of relying on import side-effects. It will lead to the missing x-refs in code search according to go/pywald-sawmill-analysis.
PiperOrigin-RevId: 604788105
This commit is contained in:
parent
ca77e5639f
commit
9b27d43e70
@ -807,6 +807,7 @@ pytype_strict_library(
|
||||
":core",
|
||||
":jax",
|
||||
":mlir",
|
||||
":sharding_impls",
|
||||
"//jax/_src/lib",
|
||||
] + if_building_jaxlib([
|
||||
"//jaxlib/mlir:ir",
|
||||
|
@ -30,6 +30,7 @@ from jax._src.numpy.reductions import _moveaxis
|
||||
from jax._src.numpy.util import implements, check_arraylike, _broadcast_to, _where
|
||||
from jax._src.numpy.vectorize import vectorize
|
||||
from jax._src.util import canonicalize_axis, set_module
|
||||
from jax._src import pjit
|
||||
import numpy as np
|
||||
|
||||
|
||||
@ -53,8 +54,8 @@ def get_if_single_primitive(fun: Callable[..., Any], *args: Any) -> jax.core.Pri
|
||||
eqn = jaxpr.eqns[0]
|
||||
if (eqn.invars, eqn.outvars) != (jaxpr.jaxpr.invars, jaxpr.jaxpr.outvars):
|
||||
return None
|
||||
elif (eqn.primitive == jax._src.pjit.pjit_p and
|
||||
all(jax._src.pjit.is_unspecified(sharding) for sharding in
|
||||
elif (eqn.primitive == pjit.pjit_p and
|
||||
all(pjit.is_unspecified(sharding) for sharding in
|
||||
(*eqn.params['in_shardings'], *eqn.params['out_shardings']))):
|
||||
jaxpr = jaxpr.eqns[0].params['jaxpr']
|
||||
else:
|
||||
|
@ -34,6 +34,7 @@ from jax._src.lib import tpu
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src import sharding_impls
|
||||
from jax.interpreters import xla
|
||||
from jaxlib.mlir import ir
|
||||
from jaxlib.mlir.dialects import stablehlo
|
||||
@ -184,7 +185,6 @@ def _tpu_custom_call_lowering(
|
||||
else:
|
||||
result_type = mlir.aval_to_ir_type(out_avals[0])
|
||||
axis_context = ctx.module_context.axis_context
|
||||
sharding_impls = jax._src.sharding_impls # pylint: disable=protected-access
|
||||
if isinstance(axis_context, sharding_impls.SPMDAxisContext):
|
||||
if axis_context.manual_axes != frozenset(axis_context.mesh.axis_names):
|
||||
raise NotImplementedError(
|
||||
|
@ -52,6 +52,7 @@ from jax._src import custom_derivatives
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
from jax._src import debugging
|
||||
from jax._src.ad_checkpoint import saved_residuals
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
@ -6310,7 +6311,7 @@ class JaxprTest(jtu.JaxTestCase):
|
||||
jax.debug.print("{}", x)
|
||||
return x
|
||||
jaxpr = jax.make_jaxpr(f)(np.int32(0))
|
||||
self.assertEqual(jaxpr.eqns[0].primitive, jax._src.debugging.debug_callback_p)
|
||||
self.assertEqual(jaxpr.eqns[0].primitive, debugging.debug_callback_p)
|
||||
self.assertStartsWith(str(jaxpr.eqns[0]), "debug_callback[", )
|
||||
|
||||
|
||||
|
@ -28,6 +28,7 @@ import jax.numpy as jnp
|
||||
import jax.profiler
|
||||
from jax import config
|
||||
import jax._src.test_util as jtu
|
||||
from jax._src import profiler
|
||||
|
||||
try:
|
||||
import portpicker
|
||||
@ -121,7 +122,7 @@ class ProfilerTest(unittest.TestCase):
|
||||
jnp.ones(jax.local_device_count())
|
||||
)
|
||||
finally:
|
||||
fdo_profile = jax._src.profiler.stop_and_get_fdo_profile()
|
||||
fdo_profile = profiler.stop_and_get_fdo_profile()
|
||||
if jtu.test_device_matches(["gpu"]) and jtu.is_device_cuda():
|
||||
self.assertIn(b"copy", fdo_profile)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user