Add a functools.partial subclass in tree_util

This commit is contained in:
Stephan Hoyer 2019-07-09 11:38:23 -07:00
parent 76eda746bd
commit a45fc83eef
2 changed files with 84 additions and 0 deletions

View File

@ -35,6 +35,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from collections import namedtuple
import itertools as it
from six.moves import reduce
@ -253,3 +254,27 @@ def _namedtuple_node(t):
NamedtupleNode = NodeType('namedtuple',
lambda xs: (tuple(xs), type(xs)),
lambda t, xs: t(*xs))
class Partial(functools.partial):
"""A version of functools.partial that works in pytrees.
Use it for partial function evaluation in a way that is compatibile with JAX's
transformations, e.g., ``Partial(func, *args, **kwargs)``.
(You need to explicitly opt-in to this behavior because we didn't want to give
functools.partial different semantics than normal function closures.)
"""
def _partial_to_iterable(partial_):
values = partial_.args + tuple(partial_.keywords.values())
spec = (partial_.func, len(partial_.args), tuple(partial_.keywords.keys()))
return values, spec
def _iterable_to_partial(spec, values):
func, args_count, keys = spec
args = values[:args_count]
keywords = dict(zip(keys, values[args_count:]))
return Partial(func, *args, **keywords)
register_pytree_node(Partial, _partial_to_iterable, _iterable_to_partial)

59
tests/tree_util_tests.py Normal file
View File

@ -0,0 +1,59 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
from absl.testing import parameterized
from jax import tree_util
from jax import test_util as jtu
def _dummy_func(*args, **kwargs):
return
class TreeTest(jtu.JaxTestCase):
@parameterized.parameters(
((1, 2), ),
([3], ),
({'a': 1, 'b': 2}, )
)
def testRoundtrip(self, inputs):
xs, tree = tree_util.tree_flatten(inputs)
actual = tree_util.tree_unflatten(tree, xs)
self.assertEqual(actual, inputs)
@parameterized.parameters(
(tree_util.Partial(_dummy_func), ),
(tree_util.Partial(_dummy_func, 1, 2), ),
(tree_util.Partial(_dummy_func, x='a'), ),
(tree_util.Partial(_dummy_func, 1, 2, 3, x=4, y=5), ),
)
def testRoundtrip(self, inputs):
xs, tree = tree_util.tree_flatten(inputs)
actual = tree_util.tree_unflatten(tree, xs)
# functools.partial does not support equality comparisons:
# https://stackoverflow.com/a/32786109/809705
self.assertEqual(actual.func, inputs.func)
self.assertEqual(actual.args, inputs.args)
self.assertEqual(actual.keywords, inputs.keywords)
if __name__ == "__main__":
absltest.main()