Mirror of https://github.com/ROCm/jax.git (synced 2025-04-15 19:36:06 +00:00).
Commit: Removed leftover mentions of xmap from the code
PiperOrigin-RevId: 713202387
This commit is contained in:
parent 81db3219b7
commit 90201ce2b7
@@ -56,6 +56,5 @@ echo "Running GPU tests..."
 "$JAXCI_PYTHON" -m pytest -n $num_processes --tb=short --maxfail=20 \
     tests examples \
     --deselect=tests/multi_device_test.py::MultiDeviceTest::test_computation_follows_data \
-    --deselect=tests/xmap_test.py::XMapTest::testCollectivePermute2D \
     --deselect=tests/multiprocess_gpu_test.py::MultiProcessGpuTest::test_distributed_jax_visible_devices \
     --deselect=tests/compilation_cache_test.py::CompilationCacheTest::test_task_using_cache_metric
|
@@ -73,14 +73,8 @@ see the {ref}`export` APIs.
 See the {mod}`jax.stages` documentation for more details on what functionality
 the lowering and compiled functions provide.
 
-In place of `jax.jit` above, you can also `lower(...)` the result of
-{func}`jax.pmap`, as well as `pjit` and `xmap` (from
-{mod}`jax.experimental.pjit` and {mod}`jax.experimental.maps` respectively). In
-each case, you can `compile()` the result similarly.
-
 All optional arguments to `jit`---such as `static_argnums`---are respected in
-the corresponding lowering, compilation, and execution. Again the same goes for
-`pmap`, `pjit`, and `xmap`.
+the corresponding lowering, compilation, and execution.
 
 In the example above, we can replace the arguments to `lower` with any objects
 that have `shape` and `dtype` attributes:
@@ -1179,7 +1179,7 @@ class InputsHandler:
 
 
 class ResultsHandler:
-  # `out_avals` is the `Array` global avals when using pjit or xmap. It is the
+  # `out_avals` is the `Array` global avals when using pjit. It is the
   # local one when using `pmap`.
   __slots__ = ("handlers", "out_shardings", "out_avals")
 
@@ -76,7 +76,7 @@ class ResourceEnv(NamedTuple):
 @util.cache(max_size=128, trace_context_in_key=False)
 def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh:
   if global_mesh.empty:
-    return global_mesh
+    return global_mesh
   is_local_device = np.vectorize(
       lambda d: d.process_index == process_index, otypes=[bool])(global_mesh.devices)
   subcube_indices = []
@@ -96,9 +96,9 @@ def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh:
   # subcube that hull will contain non-local devices.
   if not is_local_device[subcube_indices_tuple].all():
     raise ValueError(
-        "When passing host local inputs to pjit or xmap, devices "
-        "connected to a single host must form a contiguous subcube of the "
-        "global device mesh")
+        "When passing host local inputs to pjit, devices connected to a single"
+        " host must form a contiguous subcube of the global device mesh"
+    )
   return Mesh(global_mesh.devices[subcube_indices_tuple], global_mesh.axis_names)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user