mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
generalize ravel_pytree to handle int types, add tests
This commit is contained in:
parent
d75becbf67
commit
5c6ff67e4e
@ -12,12 +12,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .tree_util import tree_flatten, tree_unflatten
|
||||
from ._src.util import safe_zip
|
||||
from ._src.util import safe_zip, unzip2
|
||||
|
||||
import jax.numpy as jnp
|
||||
from jax.api import vjp
|
||||
from jax import dtypes
|
||||
from jax import lax
|
||||
|
||||
zip = safe_zip
|
||||
|
||||
@ -26,18 +30,40 @@ def ravel_pytree(pytree):
|
||||
"""Ravel (i.e. flatten) a pytree of arrays down to a 1D array.
|
||||
|
||||
Args:
|
||||
pytree: a pytree to ravel.
|
||||
pytree: a pytree of arrays and scalars to ravel.
|
||||
|
||||
Returns:
|
||||
A pair where the first element is a 1D array representing the flattened and
|
||||
concatenated leaf values, and the second element is a callable for
|
||||
unflattening a 1D vector of the same length back to a pytree of of the same
|
||||
structure as the input ``pytree``.
|
||||
concatenated leaf values, with dtype determined by promoting the dtypes of
|
||||
leaf values, and the second element is a callable for unflattening a 1D
|
||||
vector of the same length back to a pytree of of the same structure as the
|
||||
input ``pytree``. If the input pytree is empty (i.e. has no leaves) then as
|
||||
a convention a 1D empty array of dtype float32 is returned in the first
|
||||
component of the output.
|
||||
|
||||
For details on dtype promotion, see
|
||||
https://jax.readthedocs.io/en/latest/type_promotion.html.
|
||||
|
||||
"""
|
||||
leaves, treedef = tree_flatten(pytree)
|
||||
flat, unravel_list = vjp(_ravel_list, *leaves)
|
||||
flat, unravel_list = _ravel_list(leaves)
|
||||
unravel_pytree = lambda flat: tree_unflatten(treedef, unravel_list(flat))
|
||||
return flat, unravel_pytree
|
||||
|
||||
def _ravel_list(*lst):
|
||||
return jnp.concatenate([jnp.ravel(elt) for elt in lst]) if lst else jnp.array([])
|
||||
def _ravel_list(lst):
|
||||
if not lst: return jnp.array([], jnp.float32), lambda _: []
|
||||
from_dtypes = [dtypes.dtype(l) for l in lst]
|
||||
to_dtype = dtypes.result_type(*from_dtypes)
|
||||
sizes, shapes = unzip2((jnp.size(x), jnp.shape(x)) for x in lst)
|
||||
indices = np.cumsum(sizes)
|
||||
|
||||
def unravel(arr):
|
||||
chunks = jnp.split(arr, indices[:-1])
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore") # ignore complex-to-real cast warning
|
||||
return [lax.convert_element_type(chunk.reshape(shape), dtype)
|
||||
for chunk, shape, dtype in zip(chunks, shapes, from_dtypes)]
|
||||
|
||||
ravel = lambda e: jnp.ravel(lax.convert_element_type(e, to_dtype))
|
||||
raveled = jnp.concatenate([ravel(e) for e in lst])
|
||||
return raveled, unravel
|
||||
|
@ -20,6 +20,9 @@ from absl.testing import parameterized
|
||||
|
||||
from jax import test_util as jtu
|
||||
from jax import tree_util
|
||||
from jax import flatten_util
|
||||
from jax import dtypes
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
def _dummy_func(*args, **kwargs):
|
||||
@ -274,5 +277,56 @@ class TreeTest(jtu.JaxTestCase):
|
||||
FlatCache({"a": [3, 4], "b": [5, 6]}))
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
|
||||
class RavelUtilTest(jtu.JaxTestCase):
|
||||
|
||||
def testFloats(self):
|
||||
tree = [jnp.array([3.], jnp.float32),
|
||||
jnp.array([[1., 2.], [3., 4.]], jnp.float32)]
|
||||
raveled, unravel = flatten_util.ravel_pytree(tree)
|
||||
self.assertEqual(raveled.dtype, jnp.float32)
|
||||
tree_ = unravel(raveled)
|
||||
self.assertAllClose(tree, tree_, atol=0., rtol=0.)
|
||||
|
||||
def testInts(self):
|
||||
tree = [jnp.array([3], jnp.int32),
|
||||
jnp.array([[1, 2], [3, 4]], jnp.int32)]
|
||||
raveled, unravel = flatten_util.ravel_pytree(tree)
|
||||
self.assertEqual(raveled.dtype, jnp.int32)
|
||||
tree_ = unravel(raveled)
|
||||
self.assertAllClose(tree, tree_, atol=0., rtol=0.)
|
||||
|
||||
def testMixedFloatInt(self):
|
||||
tree = [jnp.array([3], jnp.int32),
|
||||
jnp.array([[1., 2.], [3., 4.]], jnp.float32)]
|
||||
raveled, unravel = flatten_util.ravel_pytree(tree)
|
||||
self.assertEqual(raveled.dtype, dtypes.promote_types(jnp.float32, jnp.int32))
|
||||
tree_ = unravel(raveled)
|
||||
self.assertAllClose(tree, tree_, atol=0., rtol=0.)
|
||||
|
||||
def testMixedIntBool(self):
|
||||
tree = [jnp.array([0], jnp.bool_),
|
||||
jnp.array([[1, 2], [3, 4]], jnp.int32)]
|
||||
raveled, unravel = flatten_util.ravel_pytree(tree)
|
||||
self.assertEqual(raveled.dtype, dtypes.promote_types(jnp.bool_, jnp.int32))
|
||||
tree_ = unravel(raveled)
|
||||
self.assertAllClose(tree, tree_, atol=0., rtol=0.)
|
||||
|
||||
def testMixedFloatComplex(self):
|
||||
tree = [jnp.array([1.], jnp.float32),
|
||||
jnp.array([[1, 2 + 3j], [3, 4]], jnp.complex64)]
|
||||
raveled, unravel = flatten_util.ravel_pytree(tree)
|
||||
self.assertEqual(raveled.dtype, dtypes.promote_types(jnp.float32, jnp.complex64))
|
||||
tree_ = unravel(raveled)
|
||||
self.assertAllClose(tree, tree_, atol=0., rtol=0.)
|
||||
|
||||
def testEmpty(self):
|
||||
tree = []
|
||||
raveled, unravel = flatten_util.ravel_pytree(tree)
|
||||
self.assertEqual(raveled.dtype, jnp.float32) # convention
|
||||
tree_ = unravel(raveled)
|
||||
self.assertAllClose(tree, tree_, atol=0., rtol=0.)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user