From 3e93833ed8e0e1348f1cc1b3a21469036edccfb5 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 13 Apr 2023 15:04:48 -0700 Subject: [PATCH] Remove in_parts, out_parts from jax.xla_computation since they were only used for sharded_jit and sharded_jit is long gone Also remove instantiate_const_outputs since that is unused PiperOrigin-RevId: 524113088 --- CHANGELOG.md | 2 ++ jax/_src/api.py | 25 +++++++++++-------------- tests/api_test.py | 30 ------------------------------ 3 files changed, 13 insertions(+), 44 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8489c46fd..298f74764 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,8 @@ Remember to align the itemized text with the first line of an item within a list deprecated. Please use `in_shardings` and `out_shardings` respectively. * The function `jax.numpy.msort` has been removed. It has been deprecated since JAX v0.4.1. Use `jnp.sort(a, axis=0)` instead. + * `in_parts` and `out_parts` arguments have been removed from `jax.xla_computation` + since they were only used with sharded_jit and sharded_jit is long gone. ## jaxlib 0.4.9 diff --git a/jax/_src/api.py b/jax/_src/api.py index bae660441..3f1069390 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -493,6 +493,15 @@ def xla_computation(fun: Callable, """ del instantiate_const_outputs # Unused + if in_parts is not None: + raise ValueError( + "in_parts has been deprecated. Please use the ahead of time APIs. You" + " can read more here: https://jax.readthedocs.io/en/latest/aot.html") + if out_parts is not None: + raise ValueError( + "out_parts has been deprecated. Please use the ahead of time APIs. You" + " can read more here: https://jax.readthedocs.io/en/latest/aot.html") + check_callable(fun) static_argnums = _ensure_index_tuple(static_argnums) donate_argnums = _ensure_index_tuple(donate_argnums) @@ -525,11 +534,6 @@ def xla_computation(fun: Callable, else: donated_invars = (False,) * len(args_flat) - if in_parts is None: - in_parts_flat = None - else: - in_parts_flat = tuple(flatten_axes( - "xla_computation in_parts", in_tree.children()[0], in_parts)) jaxtree_fun, out_tree = flatten_fun(f, in_tree) avals = map(shaped_abstractify, args_flat) with ExitStack() as stack: @@ -538,11 +542,6 @@ def xla_computation(fun: Callable, jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals) jaxpr = dispatch.apply_outfeed_rewriter(jaxpr) axis_env_ = make_axis_env(dispatch.jaxpr_replicas(jaxpr)) - if out_parts is None: - out_parts_flat = None - else: - out_parts_flat = tuple(flatten_axes( - "xla_computation out_parts", out_tree(), out_parts)) unordered_effects = list( effects.ordered_effects.filter_not_in(jaxpr.effects)) ordered_effects = list( @@ -558,10 +557,8 @@ def xla_computation(fun: Callable, name_stack=source_info_util.new_name_stack( wrap_name(fun_name, "xla_computation")), donated_args=donated_invars, - arg_shardings=(None if in_parts_flat is None else map( - xla.sharding_to_proto, in_parts_flat)), - result_shardings=(None if out_parts_flat is None else map( - xla.sharding_to_proto, out_parts_flat))) + arg_shardings=None, + result_shardings=None) if tuple_args is not None: should_tuple = tuple_args else: diff --git a/tests/api_test.py b/tests/api_test.py index 18a540c83..878e7d3ca 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -2721,36 +2721,6 @@ class APITest(jtu.JaxTestCase): api.ShapeDtypeStruct(shape=(2,), dtype=jnp.float32)) self.assertEqual(shape_tree, expected) - def test_xla_computation_partitioned(self): - def f(x, y): - return jnp.dot(x, y) + 1 - - x = jax.ShapeDtypeStruct((8, 8), np.float32) - y = jax.ShapeDtypeStruct((8, 16), np.float32) - xla_comp = api.xla_computation(f, in_parts=(P(2, 2), None), - out_parts=P(4, 1))(x, y) - hlo_text = xla_comp.as_hlo_text() - self.assertIn('sharding={devices=[2,2]0,1,2,3}', hlo_text) - self.assertIn('sharding={replicated}', hlo_text) - self.assertIn('sharding={{devices=[4,1]0,1,2,3}}', hlo_text) - - def test_xla_computation_replicated_and_partitioned(self): - def f(x, y): - return jnp.dot(x, y), lax.psum(x, 'i') - - x = jax.ShapeDtypeStruct((8, 8), np.float32) - y = jax.ShapeDtypeStruct((8, 16), np.float32) - axis_env = [('i', 4)] - xla_comp = api.xla_computation(f, axis_env=axis_env, - in_parts=(P(2, 2), None), - out_parts=(P(4, 1), None))(x, y) - hlo_text = xla_comp.as_hlo_text() - self.assertIn('all-reduce', hlo_text) - self.assertIn('replica_groups={{0,1,2,3}}', hlo_text) - self.assertIn('sharding={devices=[2,2]0,1,2,3}', hlo_text) - self.assertIn('sharding={replicated}', hlo_text) - self.assertIn('sharding={{devices=[4,1]0,1,2,3}, {replicated}}', hlo_text) - def test_xla_computation_psum_constant(self): f = lambda: jax.lax.psum(1, "i") api.xla_computation(f, axis_env=[("i", 2)])() # doesn't crash