compiled.input_layouts() should preserve the structure of the original in_tree.

JAX by default DCE's arguments that are unused which changes the in_layouts available on the `executable`. This breaks when we try to unflatten the said in_layouts with the original in_tree (because in_tree has all the args DCE'd + non-DCE'd).

The in_layouts that we return to the user should contain layouts for DCE'd + non-DCE'd args. So fill the DCE'd layouts with None which means the default layout. This does not affect the actual HLO computation because JAX will discard the DCE'd layouts anyways, consequently discarding the jax.Arrays created with those layouts.

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 585790912
This commit is contained in:
Yash Katariya 2023-11-27 16:23:24 -08:00 committed by jax authors
parent 36b3a3211d
commit 2ed0fc4d1c
2 changed files with 27 additions and 0 deletions

View File

@ -501,6 +501,10 @@ class Compiled(Stage):
def _input_layouts(self):
layouts_flat = self._executable.input_layouts()
assert all(isinstance(l, SpecifiedLayout) for l in layouts_flat)
# Some input layouts got DCE'd
if self.in_tree.num_leaves > len(layouts_flat):
layouts_flat = [layouts_flat[i] if i in self._executable._kept_var_idx
else None for i in range(self.in_tree.num_leaves)]
return tree_util.tree_unflatten(self.in_tree, layouts_flat) # pytype: disable=attribute-error
def _output_layouts(self):

View File

@ -22,6 +22,7 @@ from jax.sharding import NamedSharding, PartitionSpec as P
from jax._src import config
from jax._src import layout
from jax._src import test_util as jtu
from jax._src.util import safe_zip
from jax._src import xla_bridge
from jax._src.lib import xla_extension_version
@ -164,6 +165,28 @@ class LayoutTest(jtu.JaxTestCase):
self.assertArraysEqual(out, np_inp.T)
self.assertEqual(out.sharding, s)
def test_dce_in_layouts(self):
def f(x, y):
return x * 2
shape = (8, 2)
inp = np.arange(math.prod(shape)).reshape(shape)
compiled = jax.jit(f).lower(inp, inp, _in_layouts=layout.AUTO,
_out_layouts=layout.AUTO).compile()
arg_layouts, _ = compiled._input_layouts()
out1 = compiled(inp, inp)
compiled2 = jax.jit(f).lower(inp, inp, _in_layouts=arg_layouts).compile()
out2 = compiled2(inp, inp)
for l1, l2 in safe_zip(arg_layouts, compiled2._input_layouts()[0]):
self.assertEqual(l1, l2)
self.assertArraysEqual(out1, out2)
# TODO(yashkatariya, frostig): Also use the arg_layouts to create an Array
# and then pass that back into compiled.
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())