mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Delete soft_pmap as it has no users. Please use pjit
or xmap
if you do want soft_pmap.
`jax.soft_pmap` is undocumented. If it were documented, a deprecation period would have been provided. PiperOrigin-RevId: 474145090
This commit is contained in:
parent
dc7db8d1b4
commit
da90234cae
@ -18,6 +18,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
* Breaking changes
|
||||
* `jax._src` is no longer imported into the from the public `jax` namespace.
|
||||
This may break users that were using JAX internals.
|
||||
* `jax.soft_pmap` has been deleted. Please use `pjit` or `xmap` instead.
|
||||
`jax.soft_pmap` is undocumented. If it were documented, a deprecation period
|
||||
would have been provided.
|
||||
|
||||
## jax 0.3.17 (Aug 31, 2022)
|
||||
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.16...jax-v0.3.17).
|
||||
|
@ -112,7 +112,6 @@ from jax._src.api import (
|
||||
xla, # TODO(phawkins): update users to avoid this.
|
||||
xla_computation as xla_computation,
|
||||
)
|
||||
from jax.experimental.maps import soft_pmap as soft_pmap
|
||||
from jax.version import __version__ as __version__
|
||||
from jax.version import __version_info__ as __version_info__
|
||||
|
||||
|
@ -1899,29 +1899,6 @@ class NoQuotesStr(str):
|
||||
__repr__ = str.__str__
|
||||
|
||||
|
||||
# -------- soft_pmap --------
|
||||
|
||||
def soft_pmap(fun: Callable, axis_name: Optional[AxisName] = None, in_axes=0
|
||||
) -> Callable:
|
||||
warn("soft_pmap is an experimental feature and probably has bugs!")
|
||||
_check_callable(fun)
|
||||
axis_name = core._TempAxisName(fun) if axis_name is None else axis_name
|
||||
|
||||
if any(axis != 0 for axis in tree_leaves(in_axes)):
|
||||
raise ValueError(f"soft_pmap in_axes leaves must be 0 or None, got {in_axes}")
|
||||
proxy = object()
|
||||
in_axes = _replace_nones(proxy, in_axes)
|
||||
in_axes = tree_map(lambda i: {i: axis_name} if i is not proxy else {}, in_axes)
|
||||
|
||||
|
||||
@wraps(fun)
|
||||
def f_pmapped(*args, **kwargs):
|
||||
mesh_devices = np.array(xb.local_devices())
|
||||
with Mesh(mesh_devices, ['devices']):
|
||||
return xmap(fun, in_axes=in_axes, out_axes={0: axis_name},
|
||||
axis_resources={axis_name: 'devices'})(*args, **kwargs)
|
||||
return f_pmapped
|
||||
|
||||
# -------- config flags --------
|
||||
|
||||
def _thread_local_flag_unsupported(_):
|
||||
|
@ -38,7 +38,7 @@ from jax._src.lax import parallel
|
||||
from jax._src import api as src_api
|
||||
from jax import random
|
||||
from jax.core import ShapedArray
|
||||
from jax import (pmap, soft_pmap, jit, vmap, jvp, grad, make_jaxpr,
|
||||
from jax import (pmap, jit, vmap, jvp, grad, make_jaxpr,
|
||||
linearize, device_put)
|
||||
from jax._src import config as jax_config
|
||||
from jax._src import device_array
|
||||
@ -1545,128 +1545,6 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
r_db = r.device_buffers
|
||||
self.assertEqual(len(r_db), 6)
|
||||
|
||||
@ignore_xmap_warning()
|
||||
def testSoftPmapBatchMatmul(self):
|
||||
if config.jax_array:
|
||||
raise unittest.SkipTest('Does not work with `Array`.')
|
||||
n = 4 * jax.device_count()
|
||||
xs = np.arange(n * 2 * 3).reshape(n, 2, 3)
|
||||
ys = np.arange(n * 3 * 4).reshape(n, 3, 4)
|
||||
ans = soft_pmap(jnp.dot, 'i')(xs, ys)
|
||||
expected = np.einsum('nij,njk->nik', xs, ys)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ignore_xmap_warning()
|
||||
def testSoftPmapBatchMatmulJit(self):
|
||||
if config.jax_array:
|
||||
raise unittest.SkipTest('Does not work with `Array`.')
|
||||
n = 4 * jax.device_count()
|
||||
xs = np.arange(n * 2 * 3).reshape(n, 2, 3)
|
||||
ys = np.arange(n * 3 * 4).reshape(n, 3, 4)
|
||||
ans = soft_pmap(jit(jnp.dot), 'i')(xs, ys)
|
||||
expected = np.einsum('nij,njk->nik', xs, ys)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ignore_xmap_warning()
|
||||
def testSoftPmapPsumConstant(self):
|
||||
if config.jax_array:
|
||||
raise unittest.SkipTest('Does not work with `Array`.')
|
||||
n = 4 * jax.device_count()
|
||||
def f(_):
|
||||
return lax.psum(1, 'i')
|
||||
ans = soft_pmap(f, 'i')(jnp.ones(n))
|
||||
expected = n * np.ones(n)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ignore_xmap_warning()
|
||||
def testSoftPmapPsum(self):
|
||||
if config.jax_array:
|
||||
raise unittest.SkipTest('Does not work with `Array`.')
|
||||
n = 4 * jax.device_count()
|
||||
def f(x):
|
||||
return x / lax.psum(x, 'i')
|
||||
ans = soft_pmap(f, 'i')(jnp.ones(n))
|
||||
expected = np.ones(n) / n
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ignore_xmap_warning()
|
||||
def testSoftPmapAxisIndex(self):
|
||||
if config.jax_array:
|
||||
raise unittest.SkipTest('Does not work with `Array`.')
|
||||
n = 4 * jax.device_count()
|
||||
def f(x):
|
||||
return x * lax.axis_index('i')
|
||||
ans = soft_pmap(f, 'i')(2 * jnp.ones(n, dtype='int32'))
|
||||
expected = 2 * np.arange(n)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ignore_xmap_warning()
|
||||
def testSoftPmapOfJit(self):
|
||||
if config.jax_array:
|
||||
raise unittest.SkipTest('Does not work with `Array`.')
|
||||
n = 4 * jax.device_count()
|
||||
def f(x):
|
||||
return 3 * x
|
||||
ans = soft_pmap(jit(f), 'i')(np.arange(n))
|
||||
expected = 3 * np.arange(n)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ignore_xmap_warning()
|
||||
@unittest.skip("not implemented") # TODO(mattjj): re-implement
|
||||
def testSoftPmapNested(self):
|
||||
if config.jax_array:
|
||||
raise unittest.SkipTest('Does not work with `Array`.')
|
||||
n = 4 * jax.device_count()
|
||||
|
||||
@partial(soft_pmap, axis_name='i')
|
||||
@partial(soft_pmap, axis_name='j')
|
||||
def f(x):
|
||||
i_size = lax.psum(1, 'i')
|
||||
return x + lax.axis_index('i') + i_size * lax.axis_index('j')
|
||||
|
||||
ans = f(jnp.zeros((n, n)))
|
||||
expected = np.arange(n ** 2).reshape(n, n).T
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ignore_xmap_warning()
|
||||
@unittest.skip("not implemented") # TODO(mattjj): re-implement
|
||||
def testGradOfSoftPmap(self):
|
||||
n = 4 * jax.device_count()
|
||||
|
||||
@partial(soft_pmap, axis_name='i')
|
||||
def f(x):
|
||||
return x * lax.axis_index('i')
|
||||
|
||||
ans = grad(lambda x: jnp.sum(f(x)))(jnp.zeros((n, n)))
|
||||
expected = np.repeat(np.arange(n)[:, None], n, axis=1)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ignore_xmap_warning()
|
||||
def testSoftPmapDevicePersistence(self):
|
||||
if config.jax_array:
|
||||
raise unittest.SkipTest('Does not work with `Array`.')
|
||||
device_count = jax.device_count()
|
||||
shape = (2 * 2 * device_count, 2, 3)
|
||||
|
||||
# check that we can maintain device persistence across calls
|
||||
x = np.arange(prod(shape)).reshape(shape)
|
||||
x = soft_pmap(lambda x: x)(x)
|
||||
self.assertIsInstance(x, pxla.ShardedDeviceArray)
|
||||
x._npy_value = np.float32(np.nan) # can't be coerced to ndarray for xfer
|
||||
x = soft_pmap(lambda x: x)(x) # doesn't crash
|
||||
self.assertIsInstance(x, pxla.ShardedDeviceArray)
|
||||
|
||||
@unittest.skip("the underlying code here is broken") # TODO(mattjj)
|
||||
def testSoftPmapAllToAll(self):
|
||||
if config.jax_array:
|
||||
raise unittest.SkipTest('Does not work with `Array`.')
|
||||
n = 4 * jax.device_count()
|
||||
def f(x):
|
||||
return lax.all_to_all(x, 'i', 0, 0)
|
||||
ans = soft_pmap(f, 'i')(jnp.arange(n ** 2).reshape(n, n))
|
||||
expected = np.arange(n ** 2).reshape(n, n).T
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
def testShardedDeviceArrayBlockUntilReady(self):
|
||||
x = np.arange(jax.device_count())
|
||||
x = self.pmap(lambda x: x)(x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user