mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
compiled.input_layouts()
should preserve the structure of the original in_tree.
JAX by default DCE's arguments that are unused which changes the in_layouts available on the `executable`. This breaks when we try to unflatten the said in_layouts with the original in_tree (because in_tree has all the args DCE'd + non-DCE'd). The in_layouts that we return to the user should contain layouts for DCE'd + non-DCE'd args. So fill the DCE'd layouts with None which means the default layout. This does not affect the actual HLO computation because JAX will discard the DCE'd layouts anyways, consequently discarding the jax.Arrays created with those layouts. Co-authored-by: Roy Frostig <frostig@google.com> PiperOrigin-RevId: 585790912
This commit is contained in:
parent
36b3a3211d
commit
2ed0fc4d1c
@ -501,6 +501,10 @@ class Compiled(Stage):
|
|||||||
def _input_layouts(self):
|
def _input_layouts(self):
|
||||||
layouts_flat = self._executable.input_layouts()
|
layouts_flat = self._executable.input_layouts()
|
||||||
assert all(isinstance(l, SpecifiedLayout) for l in layouts_flat)
|
assert all(isinstance(l, SpecifiedLayout) for l in layouts_flat)
|
||||||
|
# Some input layouts got DCE'd
|
||||||
|
if self.in_tree.num_leaves > len(layouts_flat):
|
||||||
|
layouts_flat = [layouts_flat[i] if i in self._executable._kept_var_idx
|
||||||
|
else None for i in range(self.in_tree.num_leaves)]
|
||||||
return tree_util.tree_unflatten(self.in_tree, layouts_flat) # pytype: disable=attribute-error
|
return tree_util.tree_unflatten(self.in_tree, layouts_flat) # pytype: disable=attribute-error
|
||||||
|
|
||||||
def _output_layouts(self):
|
def _output_layouts(self):
|
||||||
|
@ -22,6 +22,7 @@ from jax.sharding import NamedSharding, PartitionSpec as P
|
|||||||
from jax._src import config
|
from jax._src import config
|
||||||
from jax._src import layout
|
from jax._src import layout
|
||||||
from jax._src import test_util as jtu
|
from jax._src import test_util as jtu
|
||||||
|
from jax._src.util import safe_zip
|
||||||
from jax._src import xla_bridge
|
from jax._src import xla_bridge
|
||||||
from jax._src.lib import xla_extension_version
|
from jax._src.lib import xla_extension_version
|
||||||
|
|
||||||
@ -164,6 +165,28 @@ class LayoutTest(jtu.JaxTestCase):
|
|||||||
self.assertArraysEqual(out, np_inp.T)
|
self.assertArraysEqual(out, np_inp.T)
|
||||||
self.assertEqual(out.sharding, s)
|
self.assertEqual(out.sharding, s)
|
||||||
|
|
||||||
|
def test_dce_in_layouts(self):
|
||||||
|
def f(x, y):
|
||||||
|
return x * 2
|
||||||
|
|
||||||
|
shape = (8, 2)
|
||||||
|
inp = np.arange(math.prod(shape)).reshape(shape)
|
||||||
|
compiled = jax.jit(f).lower(inp, inp, _in_layouts=layout.AUTO,
|
||||||
|
_out_layouts=layout.AUTO).compile()
|
||||||
|
arg_layouts, _ = compiled._input_layouts()
|
||||||
|
out1 = compiled(inp, inp)
|
||||||
|
|
||||||
|
compiled2 = jax.jit(f).lower(inp, inp, _in_layouts=arg_layouts).compile()
|
||||||
|
out2 = compiled2(inp, inp)
|
||||||
|
|
||||||
|
for l1, l2 in safe_zip(arg_layouts, compiled2._input_layouts()[0]):
|
||||||
|
self.assertEqual(l1, l2)
|
||||||
|
|
||||||
|
self.assertArraysEqual(out1, out2)
|
||||||
|
|
||||||
|
# TODO(yashkatariya, frostig): Also use the arg_layouts to create an Array
|
||||||
|
# and then pass that back into compiled.
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user