mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Remove in_parts, out_parts from jax.xla_computation since they were only used for sharded_jit and sharded_jit is long gone
Also remove instantiate_const_outputs since that is unused PiperOrigin-RevId: 524113088
This commit is contained in:
parent
0fbaedf45e
commit
3e93833ed8
@ -21,6 +21,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
deprecated. Please use `in_shardings` and `out_shardings` respectively.
|
||||
* The function `jax.numpy.msort` has been removed. It has been deprecated since
|
||||
JAX v0.4.1. Use `jnp.sort(a, axis=0)` instead.
|
||||
* `in_parts` and `out_parts` arguments have been removed from `jax.xla_computation`
|
||||
since they were only used with sharded_jit and sharded_jit is long gone.
|
||||
|
||||
## jaxlib 0.4.9
|
||||
|
||||
|
@ -493,6 +493,15 @@ def xla_computation(fun: Callable,
|
||||
"""
|
||||
del instantiate_const_outputs # Unused
|
||||
|
||||
if in_parts is not None:
|
||||
raise ValueError(
|
||||
"in_parts has been deprecated. Please use the ahead of time APIs. You"
|
||||
" can read more here: https://jax.readthedocs.io/en/latest/aot.html")
|
||||
if out_parts is not None:
|
||||
raise ValueError(
|
||||
"out_parts has been deprecated. Please use the ahead of time APIs. You"
|
||||
" can read more here: https://jax.readthedocs.io/en/latest/aot.html")
|
||||
|
||||
check_callable(fun)
|
||||
static_argnums = _ensure_index_tuple(static_argnums)
|
||||
donate_argnums = _ensure_index_tuple(donate_argnums)
|
||||
@ -525,11 +534,6 @@ def xla_computation(fun: Callable,
|
||||
else:
|
||||
donated_invars = (False,) * len(args_flat)
|
||||
|
||||
if in_parts is None:
|
||||
in_parts_flat = None
|
||||
else:
|
||||
in_parts_flat = tuple(flatten_axes(
|
||||
"xla_computation in_parts", in_tree.children()[0], in_parts))
|
||||
jaxtree_fun, out_tree = flatten_fun(f, in_tree)
|
||||
avals = map(shaped_abstractify, args_flat)
|
||||
with ExitStack() as stack:
|
||||
@ -538,11 +542,6 @@ def xla_computation(fun: Callable,
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
|
||||
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
|
||||
axis_env_ = make_axis_env(dispatch.jaxpr_replicas(jaxpr))
|
||||
if out_parts is None:
|
||||
out_parts_flat = None
|
||||
else:
|
||||
out_parts_flat = tuple(flatten_axes(
|
||||
"xla_computation out_parts", out_tree(), out_parts))
|
||||
unordered_effects = list(
|
||||
effects.ordered_effects.filter_not_in(jaxpr.effects))
|
||||
ordered_effects = list(
|
||||
@ -558,10 +557,8 @@ def xla_computation(fun: Callable,
|
||||
name_stack=source_info_util.new_name_stack(
|
||||
wrap_name(fun_name, "xla_computation")),
|
||||
donated_args=donated_invars,
|
||||
arg_shardings=(None if in_parts_flat is None else map(
|
||||
xla.sharding_to_proto, in_parts_flat)),
|
||||
result_shardings=(None if out_parts_flat is None else map(
|
||||
xla.sharding_to_proto, out_parts_flat)))
|
||||
arg_shardings=None,
|
||||
result_shardings=None)
|
||||
if tuple_args is not None:
|
||||
should_tuple = tuple_args
|
||||
else:
|
||||
|
@ -2721,36 +2721,6 @@ class APITest(jtu.JaxTestCase):
|
||||
api.ShapeDtypeStruct(shape=(2,), dtype=jnp.float32))
|
||||
self.assertEqual(shape_tree, expected)
|
||||
|
||||
def test_xla_computation_partitioned(self):
|
||||
def f(x, y):
|
||||
return jnp.dot(x, y) + 1
|
||||
|
||||
x = jax.ShapeDtypeStruct((8, 8), np.float32)
|
||||
y = jax.ShapeDtypeStruct((8, 16), np.float32)
|
||||
xla_comp = api.xla_computation(f, in_parts=(P(2, 2), None),
|
||||
out_parts=P(4, 1))(x, y)
|
||||
hlo_text = xla_comp.as_hlo_text()
|
||||
self.assertIn('sharding={devices=[2,2]0,1,2,3}', hlo_text)
|
||||
self.assertIn('sharding={replicated}', hlo_text)
|
||||
self.assertIn('sharding={{devices=[4,1]0,1,2,3}}', hlo_text)
|
||||
|
||||
def test_xla_computation_replicated_and_partitioned(self):
|
||||
def f(x, y):
|
||||
return jnp.dot(x, y), lax.psum(x, 'i')
|
||||
|
||||
x = jax.ShapeDtypeStruct((8, 8), np.float32)
|
||||
y = jax.ShapeDtypeStruct((8, 16), np.float32)
|
||||
axis_env = [('i', 4)]
|
||||
xla_comp = api.xla_computation(f, axis_env=axis_env,
|
||||
in_parts=(P(2, 2), None),
|
||||
out_parts=(P(4, 1), None))(x, y)
|
||||
hlo_text = xla_comp.as_hlo_text()
|
||||
self.assertIn('all-reduce', hlo_text)
|
||||
self.assertIn('replica_groups={{0,1,2,3}}', hlo_text)
|
||||
self.assertIn('sharding={devices=[2,2]0,1,2,3}', hlo_text)
|
||||
self.assertIn('sharding={replicated}', hlo_text)
|
||||
self.assertIn('sharding={{devices=[4,1]0,1,2,3}, {replicated}}', hlo_text)
|
||||
|
||||
def test_xla_computation_psum_constant(self):
|
||||
f = lambda: jax.lax.psum(1, "i")
|
||||
api.xla_computation(f, axis_env=[("i", 2)])() # doesn't crash
|
||||
|
Loading…
x
Reference in New Issue
Block a user