Delete xmap and the jax.experimental.maps module. It's been 5 months since its deprecation (more than the standard 3 months deprecation period).

PiperOrigin-RevId: 655614395
This commit is contained in:
Yash Katariya 2024-07-24 10:23:29 -07:00 committed by jax authors
parent d9a7cb4490
commit 0d5dae09ff
33 changed files with 17 additions and 5801 deletions

View File

@ -10,6 +10,9 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.31
* Deletion
* xmap has been deleted. Please use {func}`shard_map` as the replacement.
* Changes
* The minimum Python version is now 3.10. 3.10 will remain the minimum
supported version until July 2025.

View File

@ -563,8 +563,7 @@ As XLA does not have enough knowledge about the custom functions to shard input
To avoid this duplication, we can:
- [custom_partitioning](https://jax.readthedocs.io/en/latest/jax.experimental.custom_partitioning.html): to make it behave like all native JAX operations (but more complicated)
- Use manual sharding
- [shard_map](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html): the new replacement for xmap
- [xmap](https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html) (now deprecated)
- [shard_map](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html)
This example demonstrates the use of custom_partitioning.

View File

@ -20,7 +20,6 @@ This section contains examples and tutorials on more advanced topics, such as Mu
notebooks/Distributed_arrays_and_automatic_parallelization
notebooks/shard_map
distributed_data_loading
notebooks/xmap_tutorial
.. toctree::
:caption: Automatic Differentiation

View File

@ -214,7 +214,6 @@ nb_execution_excludepatterns = [
# Has extra requirements: networkx, pandas, pytorch, tensorflow, etc.
'jep/9407-type-promotion.*',
# TODO(jakevdp): enable execution on the following if possible:
'notebooks/xmap_tutorial.*',
'notebooks/Distributed_arrays_and_automatic_parallelization.*',
'notebooks/autodiff_remat.*',
# Requires accelerators

View File

@ -138,7 +138,7 @@ jax.grad(f)(1.)
#### Printing in other transformations
`jax.debug.print` also works in other transformations like `xmap` and `pjit`.
`jax.debug.print` also works in other transformations like `pjit`.
### More control with `jax.debug.callback`

View File

@ -1,12 +0,0 @@
``jax.experimental.maps`` module
================================
.. automodule:: jax.experimental.maps
API
---
.. autosummary::
:toctree: _autosummary
xmap

View File

@ -17,7 +17,6 @@ Experimental Modules
jax.experimental.array_api
jax.experimental.checkify
jax.experimental.host_callback
jax.experimental.maps
jax.experimental.pjit
jax.experimental.sparse
jax.experimental.jet

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -220,7 +220,6 @@ py_library_providing_imports_info(
"_src/interpreters/ad.py",
"_src/interpreters/batching.py",
"_src/interpreters/pxla.py",
"_src/maps.py",
"_src/pjit.py",
"_src/prng.py",
"_src/public_test_util.py",
@ -262,8 +261,6 @@ py_library_providing_imports_info(
],
) + [
"experimental/attrs.py",
# until new parallelism APIs are moved out of experimental
"experimental/maps.py",
"experimental/pjit.py",
"experimental/multihost_utils.py",
"experimental/shard_map.py",
@ -1041,17 +1038,6 @@ pytype_library(
deps = [":jax"],
)
# TODO(apaszke): Remove this target
pytype_library(
name = "maps",
srcs = ["experimental/maps.py"],
visibility = ["//visibility:public"],
deps = [
":jax",
":mesh",
],
)
# TODO(apaszke): Remove this target
pytype_library(
name = "pjit",

View File

@ -180,7 +180,6 @@ import jax.experimental.compilation_cache.compilation_cache as _ccache
del _ccache
from jax._src.deprecations import register as _register_deprecation
_register_deprecation("jax-experimental-maps-module")
_register_deprecation('jax-scipy-beta-args')
_register_deprecation('tracer-hash')
del _register_deprecation

View File

@ -192,7 +192,7 @@ class ArrayImpl(basearray.Array):
# Don't rearrange if skip_checks is enabled because this assumes that the
# input buffers are already arranged properly. This usually happens when
# Array's are created as output of a JAX transformation
# (like pjit, xmap, etc).
# (like pjit, etc).
if not _skip_checks or config.enable_checks.value:
self._check_and_rearrange()

File diff suppressed because it is too large Load Diff

View File

@ -68,7 +68,7 @@ def _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes):
def hashed_index(x) -> int:
# This works for both `pjit`/`xmap` indices and `pmap` indices (which might
# This works for both `pjit` indices and `pmap` indices (which might
# have an integer instead of a slice).
assert all(v.step is None for v in x if isinstance(v, slice))
return hash(tuple((v.start, v.stop) if isinstance(v, slice) else v for v in x))
@ -1123,8 +1123,8 @@ class SPMDAxisContext:
"""A hardware axis context for parallel computations that use the GSPMD partitioner.
This includes the mesh that will later by used to execute this computation,
as well as a set of mesh axes that are currently (e.g. because the current lowering
is invoked inside an xmap) lowered in the MANUAL sharding mode.
as well as a set of mesh axes that are currently lowered in the MANUAL
sharding mode.
"""
mesh: mesh_lib.Mesh
manual_axes: frozenset[MeshAxisName] = frozenset()

View File

@ -201,13 +201,13 @@ def _tpu_custom_call_lowering(
if axis_context.manual_axes != frozenset(axis_context.mesh.axis_names):
raise NotImplementedError(
"Mosaic kernels cannot be automatically partitioned. Please wrap the"
" call in a shard_map or xmap."
" call in a shard_map."
)
elif isinstance(axis_context, sharding_impls.ShardingContext):
if axis_context.num_devices != 1:
raise NotImplementedError(
"Mosaic kernels cannot be automatically partitioned. Please wrap the"
" call in a shard_map or xmap."
" call in a shard_map."
)
elif config.has_communication:
raise NotImplementedError(

View File

@ -26,7 +26,7 @@ is invoked from the other framework.
The native serialization mode has several advantages:
* supports virtually all operations supported by native execution, e.g.,
`xmap`, `shard_map`, `pmap`, parallel collective operations, and all
`shard_map`, `pmap`, parallel collective operations, and all
primitives at all data types.
* uses standard native JAX code paths for lowering, and thus it is easier
to trust that the semantics and performance stays faithful to the native

View File

@ -78,7 +78,7 @@ def call_tf(
function must return the same type of results.
If ``call_tf`` appears in a JAX staging context (:func:`jax.jit`,
or :func:`jax.pmap`, or :func:`jax.xmap`, or a control-flow primitive) then
or :func:`jax.pmap`, or a control-flow primitive) then
``callable_tf`` will be compiled with ``tf.function(callable_tf,
jit_compile=True)``
and the resulting XLA computation will be embedded in JAX's XLA computation.

View File

@ -51,7 +51,6 @@ from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import op_shardings
from jax._src import sharding_impls
from jax._src import maps
from jax._src import mesh
from jax._src import pjit
from jax._src import prng
@ -1484,7 +1483,7 @@ class TensorFlowTrace(core.Trace):
def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
assert False, f"Encountered unexpected primitive {p}"
for unexpected in [core.call_p, maps.xmap_p]:
for unexpected in [core.call_p]:
tf_impl[unexpected] = partial(_unexpected_primitive, unexpected)
tf_impl[lax_control_flow.loops.eval_jaxpr_p] = \

View File

@ -33,7 +33,6 @@ from jax import numpy as jnp
from jax import sharding
from jax._src import config
from jax._src import core
from jax._src.maps import xmap
from jax._src import source_info_util
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
@ -1427,8 +1426,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
"none",
"jit",
"pjit", "pjit_in_shardings_None", "pjit_in_shardings_P",
"pjit_in_shardings_Sharding",
"shard_map", "xmap", "pmap"]
"pjit_in_shardings_Sharding", "shard_map", "pmap"]
for transform2 in (
["none", "pjit_in_shardings_None", "pjit_in_shardings_P",
"pjit_in_shardings_Sharding"]
@ -1483,8 +1481,6 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
shard_map=(
shard_map(func, mesh, in_specs=(P("a", None),),
out_specs=P("a", None))),
xmap=xmap(func, in_axes=({0: 'axis'},),
out_axes={0: 'axis'}, axis_resources={'axis': 'a'}),
pmap=jax.pmap(func, in_axes=0, out_axes=0),
)[transform]
return transformed_func
@ -1492,12 +1488,9 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
transformed1_func = apply_transform(
(func_shard_map if transform1 == "shard_map" else func),
transform1)
assert transform2 not in ["xmap", "shard_map"]
assert transform2 not in ["shard_map"]
transformed2_func = apply_transform(transformed1_func, transform2)
if transform1 == "xmap" and transform2 in ["pjit", "none"]:
raise unittest.SkipTest("TODO: pjit(xmap) with unspecified shardings crashes")
if transform1 == "pmap":
x = x.reshape((1, -1)) # Since we use 1 device
if not nullary:

View File

@ -31,7 +31,6 @@ from absl.testing import absltest
import jax
from jax._src import config
from jax._src import test_util as jtu
from jax._src import maps # Needed for config flags.
import numpy as np

View File

@ -33,7 +33,6 @@ from absl.testing import absltest
import jax
from jax._src import compiler
from jax._src import config
from jax._src import maps
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax import lax

View File

@ -1,49 +0,0 @@
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from jax._src import deprecations
from jax._src.maps import (
AxisName as AxisName,
ResourceSet as ResourceSet,
SerialLoop as SerialLoop,
_prepare_axes as _prepare_axes,
make_xmap_callable as make_xmap_callable,
serial_loop as serial_loop,
xmap_p as xmap_p,
xmap as xmap,
)
from jax._src.mesh import (
EMPTY_ENV as EMPTY_ENV,
ResourceEnv as ResourceEnv,
thread_resources as thread_resources,
)
# Added March 7, 2024.
_msg = (
"jax.experimental.maps and jax.experimental.maps.xmap are deprecated and"
" will be removed in a future release. Use jax.experimental.shard_map or"
" jax.vmap with the spmd_axis_name argument for expressing SPMD"
" device-parallel computations. Please file an issue on"
" https://github.com/google/jax/issues if neither"
" jax.experimental.shard_map nor jax.vmap are suitable for your use case."
)
if deprecations.is_accelerated("jax-experimental-maps-module"):
raise ImportError(_msg)
else:
warnings.warn(_msg, DeprecationWarning, stacklevel=2)
del deprecations, warnings, _msg

View File

@ -216,30 +216,6 @@ py_test(
],
)
jax_test(
name = "xmap_test",
srcs = ["xmap_test.py"],
backend_tags = {
"gpu": [
"noasan", # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143
],
"tpu": [
"noasan", # Times out.
"nomsan", # Times out.
"notsan", # Times out.
],
},
shard_count = {
"cpu": 10,
"gpu": 4,
"tpu": 4,
},
tags = ["multiaccelerator"],
deps = [
"//jax:maps",
],
)
jax_test(
name = "memories_test",
srcs = ["memories_test.py"],

View File

@ -41,7 +41,6 @@ from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.compilation_cache_interface import CacheInterface
from jax._src.lib import xla_client
from jax._src.maps import xmap
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P
import numpy as np
@ -235,25 +234,6 @@ class CompilationCacheTest(CompilationCacheTestCase):
f(x, x + 1)
self.assertEqual(count_cache_items(), 2)
@jtu.with_mesh([("x", 2)])
def test_xmap(self):
def f(x):
return x * 2
devices = np.array(jax.local_devices()[:2])
if devices.size < 2:
raise SkipTest("Test requires 2 devices")
x = np.arange(8, dtype=np.int64).reshape((2, 2, 2))
xmap(
f, in_axes=["a", ...], out_axes=["a", ...], axis_resources={"a": "x"}
)(x)
self.assertEqual(count_cache_items(), 1)
x = np.arange(8, dtype=np.float32).reshape((2, 2, 2))
xmap(
f, in_axes=["a", ...], out_axes=["a", ...], axis_resources={"a": "x"}
)(x)
self.assertEqual(count_cache_items(), 2)
def test_cache_write_warning(self):
f = jit(lambda x: x * x)

View File

@ -24,7 +24,6 @@ from jax._src import api
from jax._src import test_util as jtu
from jax import numpy as jnp
from jax.experimental import pjit
from jax._src.maps import xmap
jax.config.parse_flags_with_absl()
@ -114,28 +113,6 @@ class DebugNaNsTest(jtu.JaxTestCase):
ans = jax.pmap(lambda x: 0. / x)(jnp.array([1.]))
ans.block_until_ready()
@jtu.ignore_warning(message=".*is an experimental.*")
def testXmap(self):
f = xmap(
lambda x: 0. / x,
in_axes=["i"],
out_axes=["i"],
axis_resources={"i": "x"})
with jax.sharding.Mesh(np.array(jax.local_devices()[:1]), ('x',)):
with self.assertRaisesRegex(
FloatingPointError,
r"invalid value \(nan\) encountered in xmap"):
ans = f(jnp.array([0.]))
ans.block_until_ready()
if jax.device_count() >= 2:
with jax.sharding.Mesh(np.array(jax.local_devices()[:2]), ('x',)):
with self.assertRaises(FloatingPointError):
ans = f(jnp.array([1., 0.]))
ans.block_until_ready()
@jtu.ignore_warning(message=".*is an experimental.*")
def testPjit(self):
if jax.device_count() < 2:

View File

@ -26,7 +26,6 @@ from jax._src import ad_checkpoint
from jax._src import debugging
from jax._src import dispatch
from jax._src import test_util as jtu
from jax._src.maps import xmap
import jax.numpy as jnp
import numpy as np
@ -786,40 +785,6 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
"[ 3 4 5 6 7 8 9 10]\n"
"[ 4 5 6 7 8 9 10 11]\n")
def test_unordered_print_of_pjit_of_xmap(self):
def f(x):
def foo(x):
idx = lax.axis_index('foo')
debug_print("{idx}: {x}", idx=idx, x=x)
return jnp.mean(x, axis=['foo'])
out = xmap(foo, in_axes=['foo'], out_axes=[...])(x)
debug_print("Out: {}", out)
return out
mesh = jax.sharding.Mesh(np.array(jax.devices()), ['dev'])
in_spec = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('dev'))
out_spec = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
f = pjit.pjit(f, in_shardings=in_spec, out_shardings=out_spec)
with mesh:
with jtu.capture_stdout() as output:
f(jnp.arange(8, dtype=jnp.int32) * 2)
lines = ["0: 0", "1: 2", "2: 4", "3: 6", "4: 8", "5: 10", "6: 12",
"7: 14", "Out: 7.0", ""]
jax.effects_barrier()
self._assertLinesEqual(output(), "\n".join(lines))
def test_unordered_print_with_xmap(self):
def f(x):
debug_print("{}", x, ordered=False)
f = xmap(f, in_axes=['a'], out_axes=None, backend='cpu',
axis_resources={'a': 'dev'})
with jax.sharding.Mesh(np.array(jax.devices()), ['dev']):
with jtu.capture_stdout() as output:
f(np.arange(40))
jax.effects_barrier()
lines = [f"{i}\n" for i in range(40)]
self._assertLinesEqual(output(), "".join(lines))
def test_unordered_print_works_in_pmap_of_while(self):
if jax.device_count() < 2:
raise unittest.SkipTest("Test requires >= 2 devices.")

View File

@ -31,7 +31,6 @@ from jax._src import util
from jax._src.interpreters import ad
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.maps import xmap
import numpy as np
config.parse_flags_with_absl()
@ -265,15 +264,6 @@ class HigherOrderPrimitiveTest(jtu.JaxTestCase):
r"Ordered effects not supported for map primitives: \[.*\]"):
jax.make_jaxpr(f)(jnp.arange(jax.local_device_count()))
def test_xmap_inherits_effects(self):
def f(x):
effect_p.bind(effect=foo_effect)
effect_p.bind(effect=bar_effect)
return x
f = xmap(f, in_axes=['a'], out_axes=['a'])
jaxpr = jax.make_jaxpr(f)(jnp.arange(jax.local_device_count()))
self.assertSetEqual(jaxpr.effects, {foo_effect, bar_effect})
def test_pjit_inherits_effects(self):
def f(x):
effect_p.bind(effect=foo_effect)

View File

@ -41,7 +41,6 @@ import jax.scipy as jsp
from jax._src.lax import control_flow as lax_control_flow
from jax._src.lax.control_flow import for_loop
from jax._src.interpreters import mlir
from jax._src.maps import xmap
jax.config.parse_flags_with_absl()
@ -2710,19 +2709,6 @@ class LaxControlFlowTest(jtu.JaxTestCase):
lax.scan(side_effecting_scan, None, jnp.ones((2, 2)))
lst[0] += 1
def test_while_loop_fixed_point_with_nested_named_axes(self):
def f(x):
z = x + lax.axis_index('a').astype(x.dtype)
y = x + lax.axis_index('b').astype(x.dtype)
def cond(carry):
i, x = carry
return x < 5
def body(carry):
i, x = carry
return i + 1, x + lax.psum(y, 'b')
return lax.while_loop(cond, body, (0, z))[1]
xmap(f, axis_sizes=dict(a=2, b=10), out_axes=(['a']), in_axes={})(1.)
def test_while_loop_fixed_point_with_batched_pred_and_consts(self):
def f(i, x):
def cond(carry):

View File

@ -266,7 +266,6 @@ class MultiProcessGpuTest(jtu.JaxTestCase):
os.environ.get("SLURM_JOB_NUM_NODES", None) != "2",
"Slurm environment with at least two nodes needed!")
@jtu.pytest_mark_if_available('SlurmMultiNodeGpuTest')
@jtu.with_config(experimental_xmap_spmd_lowering=True)
class SlurmMultiNodeGpuTest(jtu.JaxTestCase):
def sorted_devices(self):

View File

@ -35,7 +35,6 @@ from jax._src import config
from jax._src import test_util as jtu
from jax import dtypes
from jax import stages
from jax.errors import JAXTypeError
from jax import lax
from jax._src.lax import lax as lax_internal
from jax.lax import with_sharding_constraint
@ -51,11 +50,9 @@ from jax._src.sharding_impls import (
AUTO, UNSPECIFIED, NamedSharding, GSPMDSharding, PositionalSharding,
SingleDeviceSharding, parse_flatten_op_sharding)
import jax._src.pjit as pjit_lib
from jax._src.maps import xmap
from jax._src.pjit import pjit, pjit_p
from jax._src.pjit import pjit
from jax._src import mesh as mesh_lib
from jax._src.interpreters import pxla
from jax.interpreters import mlir
from jax._src.lib.mlir import dialects
from jax._src import xla_bridge
from jax._src.lib import xla_client as xc
@ -69,7 +66,6 @@ _exit_stack = contextlib.ExitStack()
def setUpModule():
_exit_stack.enter_context(jtu.set_host_platform_device_count(8))
_exit_stack.enter_context(jtu.global_config_context(experimental_xmap_spmd_lowering=True))
def tearDownModule():
_exit_stack.close()
@ -749,25 +745,6 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertListEqual(op.tile_assignment_devices(), [0, 1])
self.assertFalse(op_shardings.is_op_sharding_replicated(op))
@jtu.with_mesh([('x', 2), ('y', 1)])
def testShardingInXMap(self):
h = pjit(lambda x: x, in_shardings=P('x'), out_shardings=None)
f = xmap(lambda x: h(x * 2), in_axes=['i', ...], out_axes=['i', ...],
axis_resources={'i': 'y'})
x = jnp.arange(16).reshape((4, 4))
rule = mlir._lowerings[pjit_p]
test_rule_called = False
def _test_rule(*args, **kwargs):
nonlocal test_rule_called
test_rule_called = True
return rule(*args, **kwargs)
try:
mlir._lowerings[pjit_p] = _test_rule
f(x)
self.assertTrue(test_rule_called)
finally:
mlir._lowerings[pjit_p] = rule
@jtu.with_mesh([('x', 2)])
def testLowerWithDuckTyping(self):
x = jax.ShapeDtypeStruct((2, 2), jnp.float32)
@ -3299,36 +3276,6 @@ class ArrayPjitTest(jtu.JaxTestCase):
jax.device_put(2., NamedSharding(mesh, P())) # doesn't crash
@jtu.with_mesh([('x', 2), ('y', 1)])
def test_jit_nested_xmap_lower_arg_info(self):
def f(x, y, *args):
out = xmap(lambda x: x * 2, in_axes=['i', ...], out_axes=['i', ...],
axis_resources={'i': 'y'})(jnp.arange(8.))
return y['hi'] + args[1], out
lowered = pjit(f, in_shardings=P(), out_shardings=P()).lower(
{'hi': 1.}, {'hi': 2.}, 3., 4.)
hlo_str = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
self.assertNotIn("\"x\"", hlo_str)
self.assertIn("y['hi']", hlo_str)
# TODO(yashkatariya): Add keep_unused support to lower_mesh_computation
# and then uncomment the below line.
# self.assertNotIn("args[0]", hlo_str)
self.assertIn("args[1]", hlo_str)
@jtu.with_mesh([('x', 2), ('y', 1)])
def test_jit_nested_xmap_lower_result_info(self):
def f(x, y, z):
_ = xmap(lambda x: x * 2, in_axes=['i', ...], out_axes=['i', ...],
axis_resources={'i': 'y'})(jnp.arange(8.))
return {'a': x, 'b': [y]}
lowered = pjit(f, in_shardings=P(), out_shardings=P()).lower(
1., (2.,), [3.])
hlo_str = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
self.assertIn("jax.result_info = \"['a']\"", hlo_str)
self.assertIn("jax.result_info = \"['b'][0][0]\"", hlo_str)
def test_with_sharding_constraint_with_two_meshes(self):
if jax.device_count() < 4:
self.skipTest("Requires more than 4 devices.")
@ -4481,72 +4428,6 @@ class PJitErrorTest(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError, error):
pjit(lambda x: x, in_shardings=None, out_shardings=spec)(x)
@jtu.with_mesh([('x', 2)])
def testInputShardsXMapAxis(self):
spec = P('x')
f = xmap(
pjit(lambda x: x + 2, in_shardings=spec, out_shardings=None),
in_axes=['i', ...],
out_axes=['i', ...],
axis_resources={'i': 'x'},
)
x = jnp.arange(4).reshape((2, 2))
error = (r"pjit input has an axis resources specification of " +
spec_regex(spec) + r" that uses one or more "
"mesh axes already used by "
r"xmap to partition a named axis appearing in its named_shape \(both "
r"use mesh axes `x`\)")
with self.assertRaisesRegex(JAXTypeError, error):
f(x)
@jtu.with_mesh([('x', 2)])
def testOutputShardsXMapAxis(self):
spec = P('x')
f = xmap(
pjit(lambda x: x + 2, in_shardings=None, out_shardings=spec),
in_axes=['i', ...],
out_axes=['i', ...],
axis_resources={'i': 'x'},
)
x = jnp.arange(4).reshape((2, 2))
error = (r"pjit output has an axis resources specification of " +
spec_regex(spec) + r" that uses one or more "
"mesh axes already used by "
r"xmap to partition a named axis appearing in its named_shape \(both "
r"use mesh axes `x`\)")
with self.assertRaisesRegex(JAXTypeError, error):
f(x)
@jtu.with_mesh([('x', 2)])
def testConstraintShardsXMapAxis(self):
spec = P('x')
f = xmap(lambda x: with_sharding_constraint(x, spec),
in_axes=['i', ...], out_axes=['i', ...], axis_resources={'i': 'x'})
x = jnp.arange(4).reshape((2, 2))
error = (r"with_sharding_constraint input has an axis resources specification of " +
spec_regex(spec) + r" that uses one or more "
"mesh axes already used by "
r"xmap to partition a named axis appearing in its named_shape \(both "
r"use mesh axes `x`\)")
with self.assertRaisesRegex(JAXTypeError, error):
f(x)
@jtu.with_mesh([('x', 2)])
def testCatchesInnerXMapErrors(self):
f = pjit(
xmap(
lambda x, y: x,
in_axes=(['i'], ['j']),
out_axes=['i', 'j'],
axis_resources={'i': 'x', 'j': 'x'},
),
in_shardings=None,
out_shardings=None,
)
x = jnp.arange(4)
with self.assertRaises(JAXTypeError):
f(x, x)
def testEmptyMesh(self):
out = pjit(lambda x: x, in_shardings=None, out_shardings=None)(jnp.arange(4))
self.assertEqual(out.sharding, SingleDeviceSharding(jax.devices()[0]))

View File

@ -94,10 +94,6 @@ def args_slicer(args, bdims):
ignore_jit_of_pmap_warning = partial(
jtu.ignore_warning, message=".*jit-of-pmap.*")
ignore_xmap_warning = partial(
jtu.ignore_warning, message=".*is an experimental.*")
def create_input_array_for_pmap(input_shape, in_axes=0, input_data=None,
devices=None, sharded_dim_size=None):
if input_data is None:
@ -1950,7 +1946,6 @@ class PythonPmapTest(jtu.JaxTestCase):
indices = np.array([[[2], [1]], [[0], [0]]])
mapped_fn(indices) # doesn't crash
@ignore_xmap_warning()
def testPdotBasic(self):
num_devices = jax.device_count()

View File

@ -26,11 +26,8 @@ from jax import lax
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import maps
from jax._src import test_util as jtu
from jax._src import util
from jax._src.lib import xla_client
from jax._src.maps import xmap
from jax.experimental import io_callback
from jax.experimental import pjit
from jax.experimental.shard_map import shard_map
@ -739,46 +736,6 @@ class PureCallbackTest(jtu.JaxTestCase):
out = f(jnp.arange(float(jax.local_device_count())))
np.testing.assert_allclose(out, np.sin(np.arange(jax.local_device_count())))
def test_can_pjit_pure_callback_under_hard_xmap(self):
if not hasattr(xla_client.OpSharding.Type, 'MANUAL'):
raise unittest.SkipTest('Manual partitioning needed for pure_callback')
spmd_lowering = maps.SPMD_LOWERING.value
spmd_manual_lowering = maps.SPMD_LOWERING_MANUAL.value
config.update('experimental_xmap_spmd_lowering', True)
config.update('experimental_xmap_spmd_lowering_manual', True)
try:
mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
spec = jax.sharding.PartitionSpec('x')
def f(x):
axis_resources = {v: v for v in mesh.axis_names}
return xmap(
lambda x: jax.pure_callback(np.sin, x, x),
in_axes=(('x',),),
out_axes=('x',),
axis_resources=axis_resources,
axis_sizes=mesh.shape,
)(x)
def without_xmap_f(x):
return jax.pure_callback(np.sin, x, x)
with mesh:
inp = jnp.arange(float(jax.local_device_count()))
out = pjit.pjit(f, in_shardings=spec, out_shardings=spec)(inp)
np.testing.assert_allclose(
out, np.sin(np.arange(jax.local_device_count()))
)
finally:
config.update('experimental_xmap_spmd_lowering', spmd_lowering)
config.update(
'experimental_xmap_spmd_lowering_manual',
spmd_manual_lowering,
)
def test_cant_take_grad_of_pure_callback(self):
def sin(x):
@ -933,34 +890,6 @@ class PureCallbackTest(jtu.JaxTestCase):
# callback alive.
np.testing.assert_allclose(out, np.full((num_devices, 4), 11, np.float32))
def test_callback_inside_xmap(self):
def _callback(x):
return (x + 1.).astype(x.dtype)
def f(x):
return jax.pure_callback(_callback, x, x)
f = maps.xmap(f, in_axes=['a'], out_axes=['a'],
axis_resources={'a': 'dev'})
with jax.sharding.Mesh(np.array(jax.devices()), ['dev']):
out = f(np.arange(40.))
np.testing.assert_allclose(out, jnp.arange(1., 41.))
def test_vectorized_callback_inside_xmap(self):
def _callback(x):
return (x + 1.).astype(x.dtype)
def f(x):
return jax.pure_callback(_callback, x, x, vectorized=True)
f = maps.xmap(f, in_axes=['a'], out_axes=['a'],
axis_resources={'a': 'dev'})
with jax.sharding.Mesh(np.array(jax.devices()), ['dev']):
out = f(np.arange(40.))
np.testing.assert_allclose(out, jnp.arange(1., 41.))
def test_array_layout_is_preserved(self):
def g(x):
@ -1105,14 +1034,6 @@ class IOCallbackTest(jtu.JaxTestCase):
ValueError, "Ordered effects not supported in `pmap`"):
jax.pmap(f)(jnp.arange(jax.local_device_count()))
def test_cannot_call_ordered_io_in_xmap(self):
def f(x):
return io_callback(
lambda x: x, jax.ShapeDtypeStruct((), jnp.int32), x, ordered=True)
with self.assertRaisesRegex(
ValueError, "Cannot `vmap` ordered IO callback"):
maps.xmap(f, in_axes=([0],), out_axes=[0])(jnp.arange(16))
def test_cannot_call_ordered_io_in_vmap(self):
def f(x):
return io_callback(

File diff suppressed because it is too large Load Diff