Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)
Delete xmap and the jax.experimental.maps module.

It has been 5 months since its deprecation, longer than the standard 3-month deprecation period.

PiperOrigin-RevId: 655614395
This commit is contained in:
parent d9a7cb4490
commit 0d5dae09ff
@@ -10,6 +10,9 @@ Remember to align the itemized text with the first line of an item within a list
 
 ## jax 0.4.31
 
+* Deletion
+  * xmap has been deleted. Please use {func}`shard_map` as the replacement.
+
 * Changes
   * The minimum Python version is now 3.10. 3.10 will remain the minimum
     supported version until July 2025.
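For readers migrating, here is a minimal sketch of moving from the deleted xmap to shard_map. The mesh construction, axis name, and array size are illustrative, not part of this commit; the example assumes the leading array dimension divides the device count.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# One-dimensional device mesh; 'i' plays the role of xmap's named axis.
mesh = Mesh(np.array(jax.devices()), axis_names=('i',))

# Formerly something like:
#   xmap(lambda x: x * 2, in_axes=['i', ...], out_axes=['i', ...],
#        axis_resources={'i': 'i'})
f = shard_map(lambda x: x * 2, mesh=mesh, in_specs=(P('i'),), out_specs=P('i'))

print(f(jnp.arange(8.)))  # each device doubles its shard of x
```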
@@ -563,8 +563,7 @@ As XLA does not have enough knowledge about the custom functions to shard input
 To avoid this duplication, we can:
 - [custom_partitioning](https://jax.readthedocs.io/en/latest/jax.experimental.custom_partitioning.html): to make it behave like all native JAX operations (but more complicated)
 - Use manual sharding
-  - [shard_map](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html): the new replacement for xmap
-  - [xmap](https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html) (now deprecated)
+  - [shard_map](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html)
 
 This example demonstrates the use of custom_partitioning.
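As a hedged illustration of the "manual sharding" option kept above: with shard_map, each device works on its own shard and communicates explicitly, so XLA never needs to partition the custom function automatically. The mesh and axis name below are illustrative.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=('x',))

# Each device reduces its own shard; an explicit psum combines the partial
# results, yielding a replicated scalar (out_specs=P()).
@partial(shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P())
def total(x):
  return jax.lax.psum(jnp.sum(x), 'x')
```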
@@ -20,7 +20,6 @@ This section contains examples and tutorials on more advanced topics, such as Mu
    notebooks/Distributed_arrays_and_automatic_parallelization
    notebooks/shard_map
    distributed_data_loading
-   notebooks/xmap_tutorial
 
 .. toctree::
    :caption: Automatic Differentiation
@@ -214,7 +214,6 @@ nb_execution_excludepatterns = [
     # Has extra requirements: networkx, pandas, pytorch, tensorflow, etc.
     'jep/9407-type-promotion.*',
     # TODO(jakevdp): enable execution on the following if possible:
-    'notebooks/xmap_tutorial.*',
     'notebooks/Distributed_arrays_and_automatic_parallelization.*',
     'notebooks/autodiff_remat.*',
     # Requires accelerators
@@ -138,7 +138,7 @@ jax.grad(f)(1.)
 
 #### Printing in other transformations
 
-`jax.debug.print` also works in other transformations like `xmap` and `pjit`.
+`jax.debug.print` also works in other transformations like `pjit`.
 
 ### More control with `jax.debug.callback`
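A small sketch of the surviving behavior, shown here under `jax.jit` (which shares an implementation with `pjit`); the function and values are illustrative, not from the diff:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  # Prints the runtime value even though x is a tracer during staging.
  jax.debug.print("x = {x}", x=x)
  return x * 2

f(jnp.arange(4))
```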
@@ -1,12 +0,0 @@
-``jax.experimental.maps`` module
-================================
-
-.. automodule:: jax.experimental.maps
-
-API
----
-
-.. autosummary::
-  :toctree: _autosummary
-
-  xmap
@@ -17,7 +17,6 @@ Experimental Modules
    jax.experimental.array_api
    jax.experimental.checkify
    jax.experimental.host_callback
-   jax.experimental.maps
    jax.experimental.pjit
    jax.experimental.sparse
    jax.experimental.jet
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
jax/BUILD (14 lines changed)
@@ -220,7 +220,6 @@ py_library_providing_imports_info(
         "_src/interpreters/ad.py",
         "_src/interpreters/batching.py",
         "_src/interpreters/pxla.py",
-        "_src/maps.py",
         "_src/pjit.py",
         "_src/prng.py",
         "_src/public_test_util.py",
@@ -262,8 +261,6 @@ py_library_providing_imports_info(
         ],
     ) + [
         "experimental/attrs.py",
-        # until new parallelism APIs are moved out of experimental
-        "experimental/maps.py",
         "experimental/pjit.py",
         "experimental/multihost_utils.py",
         "experimental/shard_map.py",
@@ -1041,17 +1038,6 @@ pytype_library(
     deps = [":jax"],
 )
 
-# TODO(apaszke): Remove this target
-pytype_library(
-    name = "maps",
-    srcs = ["experimental/maps.py"],
-    visibility = ["//visibility:public"],
-    deps = [
-        ":jax",
-        ":mesh",
-    ],
-)
-
 # TODO(apaszke): Remove this target
 pytype_library(
     name = "pjit",
@@ -180,7 +180,6 @@ import jax.experimental.compilation_cache.compilation_cache as _ccache
 del _ccache
 
 from jax._src.deprecations import register as _register_deprecation
-_register_deprecation("jax-experimental-maps-module")
 _register_deprecation('jax-scipy-beta-args')
 _register_deprecation('tracer-hash')
 del _register_deprecation
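For context on the registry edited here, a sketch of the pattern the deleted module used; the deprecation id below is hypothetical, for illustration only:

```python
from jax._src import deprecations

# Hypothetical id, for illustration only.
deprecations.register("my-hypothetical-deprecation")

# Modules guarding a deprecated import (like the deleted maps.py) check this
# flag to decide between emitting a DeprecationWarning and raising.
if deprecations.is_accelerated("my-hypothetical-deprecation"):
  raise ImportError("my-hypothetical-deprecation has been accelerated")
```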
@@ -192,7 +192,7 @@ class ArrayImpl(basearray.Array):
     # Don't rearrange if skip_checks is enabled because this assumes that the
     # input buffers are already arranged properly. This usually happens when
     # Array's are created as output of a JAX transformation
-    # (like pjit, xmap, etc).
+    # (like pjit, etc).
     if not _skip_checks or config.enable_checks.value:
       self._check_and_rearrange()
jax/_src/maps.py (1877 lines changed)
File diff suppressed because it is too large
@@ -68,7 +68,7 @@ def _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes):
 
 
 def hashed_index(x) -> int:
-  # This works for both `pjit`/`xmap` indices and `pmap` indices (which might
+  # This works for both `pjit` indices and `pmap` indices (which might
   # have an integer instead of a slice).
   assert all(v.step is None for v in x if isinstance(v, slice))
   return hash(tuple((v.start, v.stop) if isinstance(v, slice) else v for v in x))
@@ -1123,8 +1123,8 @@ class SPMDAxisContext:
   """A hardware axis context for parallel computations that use the GSPMD partitioner.
 
   This includes the mesh that will later by used to execute this computation,
-  as well as a set of mesh axes that are currently (e.g. because the current lowering
-  is invoked inside an xmap) lowered in the MANUAL sharding mode.
+  as well as a set of mesh axes that are currently lowered in the MANUAL
+  sharding mode.
   """
   mesh: mesh_lib.Mesh
   manual_axes: frozenset[MeshAxisName] = frozenset()
@@ -201,13 +201,13 @@ def _tpu_custom_call_lowering(
     if axis_context.manual_axes != frozenset(axis_context.mesh.axis_names):
       raise NotImplementedError(
           "Mosaic kernels cannot be automatically partitioned. Please wrap the"
-          " call in a shard_map or xmap."
+          " call in a shard_map."
       )
   elif isinstance(axis_context, sharding_impls.ShardingContext):
     if axis_context.num_devices != 1:
       raise NotImplementedError(
           "Mosaic kernels cannot be automatically partitioned. Please wrap the"
-          " call in a shard_map or xmap."
+          " call in a shard_map."
       )
   elif config.has_communication:
     raise NotImplementedError(
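A sketch of what the updated error message asks for, with a hypothetical stand-in for a real Mosaic kernel call (the kernel, mesh, and specs here are illustrative assumptions):

```python
import jax
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=('x',))

def my_mosaic_kernel_call(x):  # hypothetical stand-in for a Mosaic kernel
  return x * 2

# shard_map lowers its body with every mesh axis in MANUAL mode, which is
# exactly the condition the lowering above checks for.
f = shard_map(my_mosaic_kernel_call, mesh=mesh,
              in_specs=(P('x'),), out_specs=P('x'))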
@@ -26,7 +26,7 @@ is invoked from the other framework.
 The native serialization mode has several advantages:
 
 * supports virtually all operations supported by native execution, e.g.,
-  `xmap`, `shard_map`, `pmap`, parallel collective operations, and all
+  `shard_map`, `pmap`, parallel collective operations, and all
   primitives at all data types.
 * uses standard native JAX code paths for lowering, and thus it is easier
   to trust that the semantics and performance stays faithful to the native
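To make that advantage concrete, a hedged sketch of converting a pmap-style program under native serialization; the converted function is illustrative, and the flag spelling reflects the jax2tf API of this era:

```python
import jax
from jax.experimental import jax2tf

# A transformation the README lists as supported under native serialization.
f = jax.pmap(lambda x: x * 2)

f_tf = jax2tf.convert(f, native_serialization=True)
```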
@@ -78,7 +78,7 @@ def call_tf(
   function must return the same type of results.
 
   If ``call_tf`` appears in a JAX staging context (:func:`jax.jit`,
-  or :func:`jax.pmap`, or :func:`jax.xmap`, or a control-flow primitive) then
+  or :func:`jax.pmap`, or a control-flow primitive) then
   ``callable_tf`` will be compiled with ``tf.function(callable_tf,
   jit_compile=True)``
   and the resulting XLA computation will be embedded in JAX's XLA computation.
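A minimal sketch of the staging scenario the docstring describes, using a stock TensorFlow op (illustrative; any compilable TF callable would do):

```python
import tensorflow as tf

import jax
import jax.numpy as jnp
from jax.experimental import jax2tf

# call_tf in a jit staging context: tf.math.sin is compiled with
# tf.function(..., jit_compile=True) and embedded in JAX's XLA computation.
f = jax.jit(jax2tf.call_tf(tf.math.sin))
print(f(jnp.float32(1.0)))
```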
@@ -51,7 +51,6 @@ from jax._src import dtypes
 from jax._src import linear_util as lu
 from jax._src import op_shardings
 from jax._src import sharding_impls
-from jax._src import maps
 from jax._src import mesh
 from jax._src import pjit
 from jax._src import prng
@@ -1484,7 +1483,7 @@ class TensorFlowTrace(core.Trace):
 def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
   assert False, f"Encountered unexpected primitive {p}"
 
-for unexpected in [core.call_p, maps.xmap_p]:
+for unexpected in [core.call_p]:
   tf_impl[unexpected] = partial(_unexpected_primitive, unexpected)
 
 tf_impl[lax_control_flow.loops.eval_jaxpr_p] = \
@@ -33,7 +33,6 @@ from jax import numpy as jnp
 from jax import sharding
 from jax._src import config
 from jax._src import core
-from jax._src.maps import xmap
 from jax._src import source_info_util
 from jax._src import test_util as jtu
 from jax._src import xla_bridge as xb
@@ -1427,8 +1426,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
           "none",
           "jit",
           "pjit", "pjit_in_shardings_None", "pjit_in_shardings_P",
-          "pjit_in_shardings_Sharding",
-          "shard_map", "xmap", "pmap"]
+          "pjit_in_shardings_Sharding", "shard_map", "pmap"]
       for transform2 in (
           ["none", "pjit_in_shardings_None", "pjit_in_shardings_P",
            "pjit_in_shardings_Sharding"]
@@ -1483,8 +1481,6 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
         shard_map=(
             shard_map(func, mesh, in_specs=(P("a", None),),
                       out_specs=P("a", None))),
-        xmap=xmap(func, in_axes=({0: 'axis'},),
-                  out_axes={0: 'axis'}, axis_resources={'axis': 'a'}),
         pmap=jax.pmap(func, in_axes=0, out_axes=0),
     )[transform]
     return transformed_func
@@ -1492,12 +1488,9 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
     transformed1_func = apply_transform(
         (func_shard_map if transform1 == "shard_map" else func),
         transform1)
-    assert transform2 not in ["xmap", "shard_map"]
+    assert transform2 not in ["shard_map"]
     transformed2_func = apply_transform(transformed1_func, transform2)
 
-    if transform1 == "xmap" and transform2 in ["pjit", "none"]:
-      raise unittest.SkipTest("TODO: pjit(xmap) with unspecified shardings crashes")
-
     if transform1 == "pmap":
       x = x.reshape((1, -1))  # Since we use 1 device
     if not nullary:
@@ -31,7 +31,6 @@ from absl.testing import absltest
 import jax
 from jax._src import config
 from jax._src import test_util as jtu
-from jax._src import maps  # Needed for config flags.
 
 import numpy as np
 
@@ -33,7 +33,6 @@ from absl.testing import absltest
 import jax
 from jax._src import compiler
 from jax._src import config
-from jax._src import maps
 from jax._src import test_util as jtu
 from jax._src import xla_bridge
 from jax import lax
@@ -1,49 +0,0 @@
-# Copyright 2020 The JAX Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import warnings
-
-from jax._src import deprecations
-from jax._src.maps import (
-  AxisName as AxisName,
-  ResourceSet as ResourceSet,
-  SerialLoop as SerialLoop,
-  _prepare_axes as _prepare_axes,
-  make_xmap_callable as make_xmap_callable,
-  serial_loop as serial_loop,
-  xmap_p as xmap_p,
-  xmap as xmap,
-)
-from jax._src.mesh import (
-  EMPTY_ENV as EMPTY_ENV,
-  ResourceEnv as ResourceEnv,
-  thread_resources as thread_resources,
-)
-
-# Added March 7, 2024.
-_msg = (
-    "jax.experimental.maps and jax.experimental.maps.xmap are deprecated and"
-    " will be removed in a future release. Use jax.experimental.shard_map or"
-    " jax.vmap with the spmd_axis_name argument for expressing SPMD"
-    " device-parallel computations. Please file an issue on"
-    " https://github.com/google/jax/issues if neither"
-    " jax.experimental.shard_map nor jax.vmap are suitable for your use case."
-)
-
-if deprecations.is_accelerated("jax-experimental-maps-module"):
-  raise ImportError(_msg)
-else:
-  warnings.warn(_msg, DeprecationWarning, stacklevel=2)
-
-del deprecations, warnings, _msg
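A sketch of the vmap alternative the deleted message points to. The axis name is illustrative; spmd_axis_name only influences sharding propagation when a mesh axis of that name is in scope, so on its own this behaves like plain vmap:

```python
import jax
import jax.numpy as jnp

# The mapped axis is named 'a'; under pjit with a mesh axis 'a', outputs of
# the mapped function are sharded along that mesh axis.
f = jax.vmap(lambda x: x * 2, spmd_axis_name='a')
print(f(jnp.arange(4.)))
```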
tests/BUILD (24 lines changed)
@@ -216,30 +216,6 @@ py_test(
     ],
 )
 
-jax_test(
-    name = "xmap_test",
-    srcs = ["xmap_test.py"],
-    backend_tags = {
-        "gpu": [
-            "noasan",  # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143
-        ],
-        "tpu": [
-            "noasan",  # Times out.
-            "nomsan",  # Times out.
-            "notsan",  # Times out.
-        ],
-    },
-    shard_count = {
-        "cpu": 10,
-        "gpu": 4,
-        "tpu": 4,
-    },
-    tags = ["multiaccelerator"],
-    deps = [
-        "//jax:maps",
-    ],
-)
-
 jax_test(
     name = "memories_test",
     srcs = ["memories_test.py"],
@@ -41,7 +41,6 @@ from jax._src import test_util as jtu
 from jax._src import xla_bridge
 from jax._src.compilation_cache_interface import CacheInterface
 from jax._src.lib import xla_client
-from jax._src.maps import xmap
 from jax.experimental.pjit import pjit
 from jax.sharding import PartitionSpec as P
 import numpy as np
@@ -235,25 +234,6 @@ class CompilationCacheTest(CompilationCacheTestCase):
       f(x, x + 1)
       self.assertEqual(count_cache_items(), 2)
 
-  @jtu.with_mesh([("x", 2)])
-  def test_xmap(self):
-    def f(x):
-      return x * 2
-
-    devices = np.array(jax.local_devices()[:2])
-    if devices.size < 2:
-      raise SkipTest("Test requires 2 devices")
-    x = np.arange(8, dtype=np.int64).reshape((2, 2, 2))
-    xmap(
-        f, in_axes=["a", ...], out_axes=["a", ...], axis_resources={"a": "x"}
-    )(x)
-    self.assertEqual(count_cache_items(), 1)
-    x = np.arange(8, dtype=np.float32).reshape((2, 2, 2))
-    xmap(
-        f, in_axes=["a", ...], out_axes=["a", ...], axis_resources={"a": "x"}
-    )(x)
-    self.assertEqual(count_cache_items(), 2)
-
   def test_cache_write_warning(self):
     f = jit(lambda x: x * x)
 
@@ -24,7 +24,6 @@ from jax._src import api
 from jax._src import test_util as jtu
 from jax import numpy as jnp
 from jax.experimental import pjit
-from jax._src.maps import xmap
 
 jax.config.parse_flags_with_absl()
 
@@ -114,28 +113,6 @@ class DebugNaNsTest(jtu.JaxTestCase):
       ans = jax.pmap(lambda x: 0. / x)(jnp.array([1.]))
       ans.block_until_ready()
 
-  @jtu.ignore_warning(message=".*is an experimental.*")
-  def testXmap(self):
-
-    f = xmap(
-        lambda x: 0. / x,
-        in_axes=["i"],
-        out_axes=["i"],
-        axis_resources={"i": "x"})
-
-    with jax.sharding.Mesh(np.array(jax.local_devices()[:1]), ('x',)):
-      with self.assertRaisesRegex(
-          FloatingPointError,
-          r"invalid value \(nan\) encountered in xmap"):
-        ans = f(jnp.array([0.]))
-        ans.block_until_ready()
-
-    if jax.device_count() >= 2:
-      with jax.sharding.Mesh(np.array(jax.local_devices()[:2]), ('x',)):
-        with self.assertRaises(FloatingPointError):
-          ans = f(jnp.array([1., 0.]))
-          ans.block_until_ready()
-
   @jtu.ignore_warning(message=".*is an experimental.*")
   def testPjit(self):
     if jax.device_count() < 2:
@@ -26,7 +26,6 @@ from jax._src import ad_checkpoint
 from jax._src import debugging
 from jax._src import dispatch
 from jax._src import test_util as jtu
-from jax._src.maps import xmap
 import jax.numpy as jnp
 import numpy as np
 
@@ -786,40 +785,6 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
         "[ 3 4 5 6 7 8 9 10]\n"
         "[ 4 5 6 7 8 9 10 11]\n")
 
-  def test_unordered_print_of_pjit_of_xmap(self):
-    def f(x):
-      def foo(x):
-        idx = lax.axis_index('foo')
-        debug_print("{idx}: {x}", idx=idx, x=x)
-        return jnp.mean(x, axis=['foo'])
-      out = xmap(foo, in_axes=['foo'], out_axes=[...])(x)
-      debug_print("Out: {}", out)
-      return out
-    mesh = jax.sharding.Mesh(np.array(jax.devices()), ['dev'])
-    in_spec = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('dev'))
-    out_spec = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
-    f = pjit.pjit(f, in_shardings=in_spec, out_shardings=out_spec)
-    with mesh:
-      with jtu.capture_stdout() as output:
-        f(jnp.arange(8, dtype=jnp.int32) * 2)
-        lines = ["0: 0", "1: 2", "2: 4", "3: 6", "4: 8", "5: 10", "6: 12",
-                 "7: 14", "Out: 7.0", ""]
-        jax.effects_barrier()
-
-      self._assertLinesEqual(output(), "\n".join(lines))
-
-  def test_unordered_print_with_xmap(self):
-    def f(x):
-      debug_print("{}", x, ordered=False)
-    f = xmap(f, in_axes=['a'], out_axes=None, backend='cpu',
-             axis_resources={'a': 'dev'})
-    with jax.sharding.Mesh(np.array(jax.devices()), ['dev']):
-      with jtu.capture_stdout() as output:
-        f(np.arange(40))
-        jax.effects_barrier()
-      lines = [f"{i}\n" for i in range(40)]
-      self._assertLinesEqual(output(), "".join(lines))
-
   def test_unordered_print_works_in_pmap_of_while(self):
     if jax.device_count() < 2:
       raise unittest.SkipTest("Test requires >= 2 devices.")
@@ -31,7 +31,6 @@ from jax._src import util
 from jax._src.interpreters import ad
 from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
-from jax._src.maps import xmap
 import numpy as np
 
 config.parse_flags_with_absl()
@@ -265,15 +264,6 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase):
         r"Ordered effects not supported for map primitives: \[.*\]"):
       jax.make_jaxpr(f)(jnp.arange(jax.local_device_count()))
 
-  def test_xmap_inherits_effects(self):
-    def f(x):
-      effect_p.bind(effect=foo_effect)
-      effect_p.bind(effect=bar_effect)
-      return x
-    f = xmap(f, in_axes=['a'], out_axes=['a'])
-    jaxpr = jax.make_jaxpr(f)(jnp.arange(jax.local_device_count()))
-    self.assertSetEqual(jaxpr.effects, {foo_effect, bar_effect})
-
   def test_pjit_inherits_effects(self):
     def f(x):
       effect_p.bind(effect=foo_effect)
@@ -41,7 +41,6 @@ import jax.scipy as jsp
 from jax._src.lax import control_flow as lax_control_flow
 from jax._src.lax.control_flow import for_loop
 from jax._src.interpreters import mlir
-from jax._src.maps import xmap
 
 jax.config.parse_flags_with_absl()
 
@@ -2710,19 +2709,6 @@ class LaxControlFlowTest(jtu.JaxTestCase):
       lax.scan(side_effecting_scan, None, jnp.ones((2, 2)))
     lst[0] += 1
 
-  def test_while_loop_fixed_point_with_nested_named_axes(self):
-    def f(x):
-      z = x + lax.axis_index('a').astype(x.dtype)
-      y = x + lax.axis_index('b').astype(x.dtype)
-      def cond(carry):
-        i, x = carry
-        return x < 5
-      def body(carry):
-        i, x = carry
-        return i + 1, x + lax.psum(y, 'b')
-      return lax.while_loop(cond, body, (0, z))[1]
-    xmap(f, axis_sizes=dict(a=2, b=10), out_axes=(['a']), in_axes={})(1.)
-
   def test_while_loop_fixed_point_with_batched_pred_and_consts(self):
     def f(i, x):
       def cond(carry):
@@ -266,7 +266,6 @@ class MultiProcessGpuTest(jtu.JaxTestCase):
     os.environ.get("SLURM_JOB_NUM_NODES", None) != "2",
     "Slurm environment with at least two nodes needed!")
 @jtu.pytest_mark_if_available('SlurmMultiNodeGpuTest')
-@jtu.with_config(experimental_xmap_spmd_lowering=True)
 class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
 
   def sorted_devices(self):
@@ -35,7 +35,6 @@ from jax._src import config
 from jax._src import test_util as jtu
 from jax import dtypes
 from jax import stages
-from jax.errors import JAXTypeError
 from jax import lax
 from jax._src.lax import lax as lax_internal
 from jax.lax import with_sharding_constraint
@@ -51,11 +50,9 @@ from jax._src.sharding_impls import (
     AUTO, UNSPECIFIED, NamedSharding, GSPMDSharding, PositionalSharding,
     SingleDeviceSharding, parse_flatten_op_sharding)
 import jax._src.pjit as pjit_lib
-from jax._src.maps import xmap
-from jax._src.pjit import pjit, pjit_p
+from jax._src.pjit import pjit
 from jax._src import mesh as mesh_lib
 from jax._src.interpreters import pxla
 from jax.interpreters import mlir
 from jax._src.lib.mlir import dialects
 from jax._src import xla_bridge
 from jax._src.lib import xla_client as xc
@@ -69,7 +66,6 @@ _exit_stack = contextlib.ExitStack()
 
 def setUpModule():
   _exit_stack.enter_context(jtu.set_host_platform_device_count(8))
-  _exit_stack.enter_context(jtu.global_config_context(experimental_xmap_spmd_lowering=True))
 
 def tearDownModule():
   _exit_stack.close()
|
||||
self.assertListEqual(op.tile_assignment_devices(), [0, 1])
|
||||
self.assertFalse(op_shardings.is_op_sharding_replicated(op))
|
||||
|
||||
@jtu.with_mesh([('x', 2), ('y', 1)])
|
||||
def testShardingInXMap(self):
|
||||
h = pjit(lambda x: x, in_shardings=P('x'), out_shardings=None)
|
||||
f = xmap(lambda x: h(x * 2), in_axes=['i', ...], out_axes=['i', ...],
|
||||
axis_resources={'i': 'y'})
|
||||
x = jnp.arange(16).reshape((4, 4))
|
||||
rule = mlir._lowerings[pjit_p]
|
||||
test_rule_called = False
|
||||
def _test_rule(*args, **kwargs):
|
||||
nonlocal test_rule_called
|
||||
test_rule_called = True
|
||||
return rule(*args, **kwargs)
|
||||
try:
|
||||
mlir._lowerings[pjit_p] = _test_rule
|
||||
f(x)
|
||||
self.assertTrue(test_rule_called)
|
||||
finally:
|
||||
mlir._lowerings[pjit_p] = rule
|
||||
|
||||
@jtu.with_mesh([('x', 2)])
|
||||
def testLowerWithDuckTyping(self):
|
||||
x = jax.ShapeDtypeStruct((2, 2), jnp.float32)
|
||||
@@ -3299,36 +3276,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
 
     jax.device_put(2., NamedSharding(mesh, P()))  # doesn't crash
 
-  @jtu.with_mesh([('x', 2), ('y', 1)])
-  def test_jit_nested_xmap_lower_arg_info(self):
-    def f(x, y, *args):
-      out = xmap(lambda x: x * 2, in_axes=['i', ...], out_axes=['i', ...],
-                 axis_resources={'i': 'y'})(jnp.arange(8.))
-      return y['hi'] + args[1], out
-
-    lowered = pjit(f, in_shardings=P(), out_shardings=P()).lower(
-        {'hi': 1.}, {'hi': 2.}, 3., 4.)
-    hlo_str = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
-    self.assertNotIn("\"x\"", hlo_str)
-    self.assertIn("y['hi']", hlo_str)
-    # TODO(yashkatariya): Add keep_unused support to lower_mesh_computation
-    # and then uncomment the below line.
-    # self.assertNotIn("args[0]", hlo_str)
-    self.assertIn("args[1]", hlo_str)
-
-  @jtu.with_mesh([('x', 2), ('y', 1)])
-  def test_jit_nested_xmap_lower_result_info(self):
-    def f(x, y, z):
-      _ = xmap(lambda x: x * 2, in_axes=['i', ...], out_axes=['i', ...],
-               axis_resources={'i': 'y'})(jnp.arange(8.))
-      return {'a': x, 'b': [y]}
-
-    lowered = pjit(f, in_shardings=P(), out_shardings=P()).lower(
-        1., (2.,), [3.])
-    hlo_str = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
-    self.assertIn("jax.result_info = \"['a']\"", hlo_str)
-    self.assertIn("jax.result_info = \"['b'][0][0]\"", hlo_str)
-
   def test_with_sharding_constraint_with_two_meshes(self):
     if jax.device_count() < 4:
       self.skipTest("Requires more than 4 devices.")
|
||||
with self.assertRaisesRegex(ValueError, error):
|
||||
pjit(lambda x: x, in_shardings=None, out_shardings=spec)(x)
|
||||
|
||||
@jtu.with_mesh([('x', 2)])
|
||||
def testInputShardsXMapAxis(self):
|
||||
spec = P('x')
|
||||
f = xmap(
|
||||
pjit(lambda x: x + 2, in_shardings=spec, out_shardings=None),
|
||||
in_axes=['i', ...],
|
||||
out_axes=['i', ...],
|
||||
axis_resources={'i': 'x'},
|
||||
)
|
||||
x = jnp.arange(4).reshape((2, 2))
|
||||
error = (r"pjit input has an axis resources specification of " +
|
||||
spec_regex(spec) + r" that uses one or more "
|
||||
"mesh axes already used by "
|
||||
r"xmap to partition a named axis appearing in its named_shape \(both "
|
||||
r"use mesh axes `x`\)")
|
||||
with self.assertRaisesRegex(JAXTypeError, error):
|
||||
f(x)
|
||||
|
||||
@jtu.with_mesh([('x', 2)])
|
||||
def testOutputShardsXMapAxis(self):
|
||||
spec = P('x')
|
||||
f = xmap(
|
||||
pjit(lambda x: x + 2, in_shardings=None, out_shardings=spec),
|
||||
in_axes=['i', ...],
|
||||
out_axes=['i', ...],
|
||||
axis_resources={'i': 'x'},
|
||||
)
|
||||
x = jnp.arange(4).reshape((2, 2))
|
||||
error = (r"pjit output has an axis resources specification of " +
|
||||
spec_regex(spec) + r" that uses one or more "
|
||||
"mesh axes already used by "
|
||||
r"xmap to partition a named axis appearing in its named_shape \(both "
|
||||
r"use mesh axes `x`\)")
|
||||
with self.assertRaisesRegex(JAXTypeError, error):
|
||||
f(x)
|
||||
|
||||
@jtu.with_mesh([('x', 2)])
|
||||
def testConstraintShardsXMapAxis(self):
|
||||
spec = P('x')
|
||||
f = xmap(lambda x: with_sharding_constraint(x, spec),
|
||||
in_axes=['i', ...], out_axes=['i', ...], axis_resources={'i': 'x'})
|
||||
x = jnp.arange(4).reshape((2, 2))
|
||||
error = (r"with_sharding_constraint input has an axis resources specification of " +
|
||||
spec_regex(spec) + r" that uses one or more "
|
||||
"mesh axes already used by "
|
||||
r"xmap to partition a named axis appearing in its named_shape \(both "
|
||||
r"use mesh axes `x`\)")
|
||||
with self.assertRaisesRegex(JAXTypeError, error):
|
||||
f(x)
|
||||
|
||||
@jtu.with_mesh([('x', 2)])
|
||||
def testCatchesInnerXMapErrors(self):
|
||||
f = pjit(
|
||||
xmap(
|
||||
lambda x, y: x,
|
||||
in_axes=(['i'], ['j']),
|
||||
out_axes=['i', 'j'],
|
||||
axis_resources={'i': 'x', 'j': 'x'},
|
||||
),
|
||||
in_shardings=None,
|
||||
out_shardings=None,
|
||||
)
|
||||
x = jnp.arange(4)
|
||||
with self.assertRaises(JAXTypeError):
|
||||
f(x, x)
|
||||
|
||||
def testEmptyMesh(self):
|
||||
out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(jnp.arange(4))
|
||||
self.assertEqual(out.sharding, SingleDeviceSharding(jax.devices()[0]))
|
||||
|
@@ -94,10 +94,6 @@ def args_slicer(args, bdims):
 ignore_jit_of_pmap_warning = partial(
     jtu.ignore_warning, message=".*jit-of-pmap.*")
 
-ignore_xmap_warning = partial(
-    jtu.ignore_warning, message=".*is an experimental.*")
-
-
 def create_input_array_for_pmap(input_shape, in_axes=0, input_data=None,
                                 devices=None, sharded_dim_size=None):
   if input_data is None:
@@ -1950,7 +1946,6 @@ class PythonPmapTest(jtu.JaxTestCase):
     indices = np.array([[[2], [1]], [[0], [0]]])
     mapped_fn(indices)  # doesn't crash
 
-  @ignore_xmap_warning()
   def testPdotBasic(self):
     num_devices = jax.device_count()
 
@@ -26,11 +26,8 @@ from jax import lax
 from jax._src import config
 from jax._src import core
 from jax._src import dispatch
-from jax._src import maps
 from jax._src import test_util as jtu
 from jax._src import util
-from jax._src.lib import xla_client
-from jax._src.maps import xmap
 from jax.experimental import io_callback
 from jax.experimental import pjit
 from jax.experimental.shard_map import shard_map
@@ -739,46 +736,6 @@ class PureCallbackTest(jtu.JaxTestCase):
       out = f(jnp.arange(float(jax.local_device_count())))
       np.testing.assert_allclose(out, np.sin(np.arange(jax.local_device_count())))
 
-  def test_can_pjit_pure_callback_under_hard_xmap(self):
-
-    if not hasattr(xla_client.OpSharding.Type, 'MANUAL'):
-      raise unittest.SkipTest('Manual partitioning needed for pure_callback')
-
-    spmd_lowering = maps.SPMD_LOWERING.value
-    spmd_manual_lowering = maps.SPMD_LOWERING_MANUAL.value
-    config.update('experimental_xmap_spmd_lowering', True)
-    config.update('experimental_xmap_spmd_lowering_manual', True)
-    try:
-      mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
-
-      spec = jax.sharding.PartitionSpec('x')
-
-      def f(x):
-        axis_resources = {v: v for v in mesh.axis_names}
-        return xmap(
-            lambda x: jax.pure_callback(np.sin, x, x),
-            in_axes=(('x',),),
-            out_axes=('x',),
-            axis_resources=axis_resources,
-            axis_sizes=mesh.shape,
-        )(x)
-
-      def without_xmap_f(x):
-        return jax.pure_callback(np.sin, x, x)
-
-      with mesh:
-        inp = jnp.arange(float(jax.local_device_count()))
-        out = pjit.pjit(f, in_shardings=spec, out_shardings=spec)(inp)
-        np.testing.assert_allclose(
-            out, np.sin(np.arange(jax.local_device_count()))
-        )
-    finally:
-      config.update('experimental_xmap_spmd_lowering', spmd_lowering)
-      config.update(
-          'experimental_xmap_spmd_lowering_manual',
-          spmd_manual_lowering,
-      )
-
   def test_cant_take_grad_of_pure_callback(self):
 
     def sin(x):
@@ -933,34 +890,6 @@ class PureCallbackTest(jtu.JaxTestCase):
     # callback alive.
     np.testing.assert_allclose(out, np.full((num_devices, 4), 11, np.float32))
 
-  def test_callback_inside_xmap(self):
-
-    def _callback(x):
-      return (x + 1.).astype(x.dtype)
-
-    def f(x):
-      return jax.pure_callback(_callback, x, x)
-
-    f = maps.xmap(f, in_axes=['a'], out_axes=['a'],
-                  axis_resources={'a': 'dev'})
-    with jax.sharding.Mesh(np.array(jax.devices()), ['dev']):
-      out = f(np.arange(40.))
-    np.testing.assert_allclose(out, jnp.arange(1., 41.))
-
-  def test_vectorized_callback_inside_xmap(self):
-
-    def _callback(x):
-      return (x + 1.).astype(x.dtype)
-
-    def f(x):
-      return jax.pure_callback(_callback, x, x, vectorized=True)
-
-    f = maps.xmap(f, in_axes=['a'], out_axes=['a'],
-                  axis_resources={'a': 'dev'})
-    with jax.sharding.Mesh(np.array(jax.devices()), ['dev']):
-      out = f(np.arange(40.))
-    np.testing.assert_allclose(out, jnp.arange(1., 41.))
-
   def test_array_layout_is_preserved(self):
 
     def g(x):
@@ -1105,14 +1034,6 @@ class IOCallbackTest(jtu.JaxTestCase):
         ValueError, "Ordered effects not supported in `pmap`"):
       jax.pmap(f)(jnp.arange(jax.local_device_count()))
 
-  def test_cannot_call_ordered_io_in_xmap(self):
-    def f(x):
-      return io_callback(
-          lambda x: x, jax.ShapeDtypeStruct((), jnp.int32), x, ordered=True)
-    with self.assertRaisesRegex(
-        ValueError, "Cannot `vmap` ordered IO callback"):
-      maps.xmap(f, in_axes=([0],), out_axes=[0])(jnp.arange(16))
-
   def test_cannot_call_ordered_io_in_vmap(self):
     def f(x):
       return io_callback(
tests/xmap_test.py (1944 lines changed)
File diff suppressed because it is too large