Removed leftover mentions of xmap from the code

PiperOrigin-RevId: 713202387
This commit is contained in:
Sergei Lebedev 2025-01-08 01:38:55 -08:00 committed by jax authors
parent 81db3219b7
commit 90201ce2b7
4 changed files with 7 additions and 14 deletions

View File

@ -56,6 +56,5 @@ echo "Running GPU tests..."
"$JAXCI_PYTHON" -m pytest -n $num_processes --tb=short --maxfail=20 \
tests examples \
--deselect=tests/multi_device_test.py::MultiDeviceTest::test_computation_follows_data \
--deselect=tests/xmap_test.py::XMapTest::testCollectivePermute2D \
--deselect=tests/multiprocess_gpu_test.py::MultiProcessGpuTest::test_distributed_jax_visible_devices \
--deselect=tests/compilation_cache_test.py::CompilationCacheTest::test_task_using_cache_metric
--deselect=tests/compilation_cache_test.py::CompilationCacheTest::test_task_using_cache_metric

View File

@ -73,14 +73,8 @@ see the {ref}`export` APIs.
See the {mod}`jax.stages` documentation for more details on what functionality
the lowering and compiled functions provide.
In place of `jax.jit` above, you can also `lower(...)` the result of
{func}`jax.pmap`, as well as `pjit` and `xmap` (from
{mod}`jax.experimental.pjit` and {mod}`jax.experimental.maps` respectively). In
each case, you can `compile()` the result similarly.
All optional arguments to `jit`---such as `static_argnums`---are respected in
the corresponding lowering, compilation, and execution. Again the same goes for
`pmap`, `pjit`, and `xmap`.
the corresponding lowering, compilation, and execution.
In the example above, we can replace the arguments to `lower` with any objects
that have `shape` and `dtype` attributes:

View File

@ -1179,7 +1179,7 @@ class InputsHandler:
class ResultsHandler:
# `out_avals` is the `Array` global avals when using pjit or xmap. It is the
# `out_avals` is the `Array` global avals when using pjit. It is the
# local one when using `pmap`.
__slots__ = ("handlers", "out_shardings", "out_avals")

View File

@ -76,7 +76,7 @@ class ResourceEnv(NamedTuple):
@util.cache(max_size=128, trace_context_in_key=False)
def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh:
if global_mesh.empty:
return global_mesh
return global_mesh
is_local_device = np.vectorize(
lambda d: d.process_index == process_index, otypes=[bool])(global_mesh.devices)
subcube_indices = []
@ -96,9 +96,9 @@ def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh:
# subcube that hull will contain non-local devices.
if not is_local_device[subcube_indices_tuple].all():
raise ValueError(
"When passing host local inputs to pjit or xmap, devices "
"connected to a single host must form a contiguous subcube of the "
"global device mesh")
"When passing host local inputs to pjit, devices connected to a single"
" host must form a contiguous subcube of the global device mesh"
)
return Mesh(global_mesh.devices[subcube_indices_tuple], global_mesh.axis_names)