From 90201ce2b743cf30922515dd658645079bb52642 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 8 Jan 2025 01:38:55 -0800 Subject: [PATCH] Removed leftover mentions of xmap from the code PiperOrigin-RevId: 713202387 --- ci/run_pytest_gpu.sh | 3 +-- docs/aot.md | 8 +------- jax/_src/interpreters/pxla.py | 2 +- jax/_src/mesh.py | 8 ++++---- 4 files changed, 7 insertions(+), 14 deletions(-) diff --git a/ci/run_pytest_gpu.sh b/ci/run_pytest_gpu.sh index 7bc249278..416d985d3 100644 --- a/ci/run_pytest_gpu.sh +++ b/ci/run_pytest_gpu.sh @@ -56,6 +56,5 @@ echo "Running GPU tests..." "$JAXCI_PYTHON" -m pytest -n $num_processes --tb=short --maxfail=20 \ tests examples \ --deselect=tests/multi_device_test.py::MultiDeviceTest::test_computation_follows_data \ ---deselect=tests/xmap_test.py::XMapTest::testCollectivePermute2D \ --deselect=tests/multiprocess_gpu_test.py::MultiProcessGpuTest::test_distributed_jax_visible_devices \ ---deselect=tests/compilation_cache_test.py::CompilationCacheTest::test_task_using_cache_metric \ No newline at end of file +--deselect=tests/compilation_cache_test.py::CompilationCacheTest::test_task_using_cache_metric diff --git a/docs/aot.md b/docs/aot.md index a5dd69a72..2dc4eadf3 100644 --- a/docs/aot.md +++ b/docs/aot.md @@ -73,14 +73,8 @@ see the {ref}`export` APIs. See the {mod}`jax.stages` documentation for more details on what functionality the lowering and compiled functions provide. -In place of `jax.jit` above, you can also `lower(...)` the result of -{func}`jax.pmap`, as well as `pjit` and `xmap` (from -{mod}`jax.experimental.pjit` and {mod}`jax.experimental.maps` respectively). In -each case, you can `compile()` the result similarly. - All optional arguments to `jit`---such as `static_argnums`---are respected in -the corresponding lowering, compilation, and execution. Again the same goes for -`pmap`, `pjit`, and `xmap`. +the corresponding lowering, compilation, and execution. In the example above, we can replace the arguments to `lower` with any objects that have `shape` and `dtype` attributes: diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 2d9d9857f..ec2d52e48 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1179,7 +1179,7 @@ class InputsHandler: class ResultsHandler: - # `out_avals` is the `Array` global avals when using pjit or xmap. It is the + # `out_avals` is the `Array` global avals when using pjit. It is the # local one when using `pmap`. __slots__ = ("handlers", "out_shardings", "out_avals") diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 25fb2b38f..34f485684 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -76,7 +76,7 @@ class ResourceEnv(NamedTuple): @util.cache(max_size=128, trace_context_in_key=False) def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh: if global_mesh.empty: - return global_mesh + return global_mesh is_local_device = np.vectorize( lambda d: d.process_index == process_index, otypes=[bool])(global_mesh.devices) subcube_indices = [] @@ -96,9 +96,9 @@ def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh: # subcube that hull will contain non-local devices. if not is_local_device[subcube_indices_tuple].all(): raise ValueError( - "When passing host local inputs to pjit or xmap, devices " - "connected to a single host must form a contiguous subcube of the " - "global device mesh") + "When passing host local inputs to pjit, devices connected to a single" + " host must form a contiguous subcube of the global device mesh" + ) return Mesh(global_mesh.devices[subcube_indices_tuple], global_mesh.axis_names)