diff --git a/jax/BUILD b/jax/BUILD index 25cc01fa0..70ed89e23 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -807,6 +807,7 @@ pytype_strict_library( ":core", ":jax", ":mlir", + ":sharding_impls", "//jax/_src/lib", ] + if_building_jaxlib([ "//jaxlib/mlir:ir", diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py index 4cf0017b9..2d3eb1edf 100644 --- a/jax/_src/numpy/ufunc_api.py +++ b/jax/_src/numpy/ufunc_api.py @@ -30,6 +30,7 @@ from jax._src.numpy.reductions import _moveaxis from jax._src.numpy.util import implements, check_arraylike, _broadcast_to, _where from jax._src.numpy.vectorize import vectorize from jax._src.util import canonicalize_axis, set_module +from jax._src import pjit import numpy as np @@ -53,8 +54,8 @@ def get_if_single_primitive(fun: Callable[..., Any], *args: Any) -> jax.core.Pri eqn = jaxpr.eqns[0] if (eqn.invars, eqn.outvars) != (jaxpr.jaxpr.invars, jaxpr.jaxpr.outvars): return None - elif (eqn.primitive == jax._src.pjit.pjit_p and - all(jax._src.pjit.is_unspecified(sharding) for sharding in + elif (eqn.primitive == pjit.pjit_p and + all(pjit.is_unspecified(sharding) for sharding in (*eqn.params['in_shardings'], *eqn.params['out_shardings']))): jaxpr = jaxpr.eqns[0].params['jaxpr'] else: diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index ae7413195..7d724d90a 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -34,6 +34,7 @@ from jax._src.lib import tpu from jax._src.lib import xla_client from jax._src.lib.mlir.dialects import hlo from jax._src.interpreters import mlir +from jax._src import sharding_impls from jax.interpreters import xla from jaxlib.mlir import ir from jaxlib.mlir.dialects import stablehlo @@ -184,7 +185,6 @@ def _tpu_custom_call_lowering( else: result_type = mlir.aval_to_ir_type(out_avals[0]) axis_context = ctx.module_context.axis_context - sharding_impls = jax._src.sharding_impls # pylint: disable=protected-access if isinstance(axis_context, sharding_impls.SPMDAxisContext): if axis_context.manual_axes != frozenset(axis_context.mesh.axis_names): raise NotImplementedError( diff --git a/tests/api_test.py b/tests/api_test.py index e5fab2244..8483c3919 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -52,6 +52,7 @@ from jax._src import custom_derivatives from jax._src import linear_util as lu from jax._src import test_util as jtu from jax._src import xla_bridge +from jax._src import debugging from jax._src.ad_checkpoint import saved_residuals from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe @@ -6310,7 +6311,7 @@ class JaxprTest(jtu.JaxTestCase): jax.debug.print("{}", x) return x jaxpr = jax.make_jaxpr(f)(np.int32(0)) - self.assertEqual(jaxpr.eqns[0].primitive, jax._src.debugging.debug_callback_p) + self.assertEqual(jaxpr.eqns[0].primitive, debugging.debug_callback_p) self.assertStartsWith(str(jaxpr.eqns[0]), "debug_callback[", ) diff --git a/tests/profiler_test.py b/tests/profiler_test.py index a9b98f13d..325392bdd 100644 --- a/tests/profiler_test.py +++ b/tests/profiler_test.py @@ -28,6 +28,7 @@ import jax.numpy as jnp import jax.profiler from jax import config import jax._src.test_util as jtu +from jax._src import profiler try: import portpicker @@ -121,7 +122,7 @@ class ProfilerTest(unittest.TestCase): jnp.ones(jax.local_device_count()) ) finally: - fdo_profile = jax._src.profiler.stop_and_get_fdo_profile() + fdo_profile = profiler.stop_and_get_fdo_profile() if jtu.test_device_matches(["gpu"]) and jtu.is_device_cuda(): self.assertIn(b"copy", fdo_profile)