Remove in_parts and out_parts from jax.xla_computation, since they were only used for sharded_jit, which is long gone

Also remove instantiate_const_outputs, since it is unused
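
For anyone still passing these arguments, the ahead-of-time (AOT) lowering APIs cover the same use case. A minimal migration sketch, not part of this change, with f, x, and y as illustrative placeholders:

import numpy as np
import jax
import jax.numpy as jnp

def f(x, y):
  return jnp.dot(x, y) + 1

x = jax.ShapeDtypeStruct((8, 8), np.float32)
y = jax.ShapeDtypeStruct((8, 16), np.float32)

# Instead of calling jax.xla_computation(f)(x, y), lower ahead of time:
lowered = jax.jit(f).lower(x, y)
print(lowered.as_text())                  # MLIR (StableHLO/MHLO) module as text
hlo = lowered.compiler_ir(dialect='hlo')  # an XlaComputation, like xla_computation returned
print(hlo.as_hlo_text())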

PiperOrigin-RevId: 524113088
Yash Katariya 2023-04-13 15:04:48 -07:00 committed by jax authors
parent 0fbaedf45e
commit 3e93833ed8
3 changed files with 13 additions and 44 deletions

CHANGELOG.md

@@ -21,6 +21,8 @@ Remember to align the itemized text with the first line of an item within a list
  deprecated. Please use `in_shardings` and `out_shardings` respectively.
* The function `jax.numpy.msort` has been removed. It has been deprecated since
  JAX v0.4.1. Use `jnp.sort(a, axis=0)` instead.
* The `in_parts` and `out_parts` arguments have been removed from `jax.xla_computation`
  since they were only used with `sharded_jit`, which is long gone.
## jaxlib 0.4.9

jax/_src/api.py

@@ -493,6 +493,15 @@ def xla_computation(fun: Callable,
  """
  del instantiate_const_outputs  # Unused
  if in_parts is not None:
    raise ValueError(
        "in_parts has been deprecated. Please use the ahead of time APIs. You"
        " can read more here: https://jax.readthedocs.io/en/latest/aot.html")
  if out_parts is not None:
    raise ValueError(
        "out_parts has been deprecated. Please use the ahead of time APIs. You"
        " can read more here: https://jax.readthedocs.io/en/latest/aot.html")
  check_callable(fun)
  static_argnums = _ensure_index_tuple(static_argnums)
  donate_argnums = _ensure_index_tuple(donate_argnums)
@@ -525,11 +534,6 @@ def xla_computation(fun: Callable,
  else:
    donated_invars = (False,) * len(args_flat)
  if in_parts is None:
    in_parts_flat = None
  else:
    in_parts_flat = tuple(flatten_axes(
        "xla_computation in_parts", in_tree.children()[0], in_parts))
  jaxtree_fun, out_tree = flatten_fun(f, in_tree)
  avals = map(shaped_abstractify, args_flat)
  with ExitStack() as stack:
@@ -538,11 +542,6 @@ def xla_computation(fun: Callable,
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
    jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
    axis_env_ = make_axis_env(dispatch.jaxpr_replicas(jaxpr))
    if out_parts is None:
      out_parts_flat = None
    else:
      out_parts_flat = tuple(flatten_axes(
          "xla_computation out_parts", out_tree(), out_parts))
    unordered_effects = list(
        effects.ordered_effects.filter_not_in(jaxpr.effects))
    ordered_effects = list(
@@ -558,10 +557,8 @@ def xla_computation(fun: Callable,
        name_stack=source_info_util.new_name_stack(
            wrap_name(fun_name, "xla_computation")),
        donated_args=donated_invars,
        arg_shardings=(None if in_parts_flat is None else map(
            xla.sharding_to_proto, in_parts_flat)),
        result_shardings=(None if out_parts_flat is None else map(
            xla.sharding_to_proto, out_parts_flat)))
        arg_shardings=None,
        result_shardings=None)
    if tuple_args is not None:
      should_tuple = tuple_args
    else:

tests/api_test.py

@@ -2721,36 +2721,6 @@ class APITest(jtu.JaxTestCase):
                api.ShapeDtypeStruct(shape=(2,), dtype=jnp.float32))
    self.assertEqual(shape_tree, expected)

  def test_xla_computation_partitioned(self):
    def f(x, y):
      return jnp.dot(x, y) + 1

    x = jax.ShapeDtypeStruct((8, 8), np.float32)
    y = jax.ShapeDtypeStruct((8, 16), np.float32)
    xla_comp = api.xla_computation(f, in_parts=(P(2, 2), None),
                                   out_parts=P(4, 1))(x, y)
    hlo_text = xla_comp.as_hlo_text()
    self.assertIn('sharding={devices=[2,2]0,1,2,3}', hlo_text)
    self.assertIn('sharding={replicated}', hlo_text)
    self.assertIn('sharding={{devices=[4,1]0,1,2,3}}', hlo_text)

  def test_xla_computation_replicated_and_partitioned(self):
    def f(x, y):
      return jnp.dot(x, y), lax.psum(x, 'i')

    x = jax.ShapeDtypeStruct((8, 8), np.float32)
    y = jax.ShapeDtypeStruct((8, 16), np.float32)
    axis_env = [('i', 4)]
    xla_comp = api.xla_computation(f, axis_env=axis_env,
                                   in_parts=(P(2, 2), None),
                                   out_parts=(P(4, 1), None))(x, y)
    hlo_text = xla_comp.as_hlo_text()
    self.assertIn('all-reduce', hlo_text)
    self.assertIn('replica_groups={{0,1,2,3}}', hlo_text)
    self.assertIn('sharding={devices=[2,2]0,1,2,3}', hlo_text)
    self.assertIn('sharding={replicated}', hlo_text)
    self.assertIn('sharding={{devices=[4,1]0,1,2,3}, {replicated}}', hlo_text)

  def test_xla_computation_psum_constant(self):
    f = lambda: jax.lax.psum(1, "i")
    api.xla_computation(f, axis_env=[("i", 2)])()  # doesn't crash
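
For reference, the partitioning behavior the deleted tests exercised can be expressed with the same AOT APIs plus explicit shardings. A hedged sketch, assuming at least four local devices; the mesh axis names and shardings here are illustrative only, not part of this commit:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def f(x, y):
  return jnp.dot(x, y) + 1

# A 2x2 device mesh; assumes the host has at least four devices.
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('a', 'b'))
x = jax.ShapeDtypeStruct((8, 8), np.float32)
y = jax.ShapeDtypeStruct((8, 16), np.float32)

lowered = jax.jit(
    f,
    in_shardings=(NamedSharding(mesh, P('a', 'b')), NamedSharding(mesh, P())),
    out_shardings=NamedSharding(mesh, P('a', 'b')),
).lower(x, y)
# The partitioning shows up in the HLO text as sharding={...} annotations.
print(lowered.compiler_ir(dialect='hlo').as_hlo_text())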