mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Add a functools.partial subclass in tree_util
This commit is contained in:
parent
76eda746bd
commit
a45fc83eef
@ -35,6 +35,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
from collections import namedtuple
|
||||
import itertools as it
|
||||
from six.moves import reduce
|
||||
@ -253,3 +254,27 @@ def _namedtuple_node(t):
|
||||
NamedtupleNode = NodeType('namedtuple',
|
||||
lambda xs: (tuple(xs), type(xs)),
|
||||
lambda t, xs: t(*xs))
|
||||
|
||||
|
||||
class Partial(functools.partial):
|
||||
"""A version of functools.partial that works in pytrees.
|
||||
|
||||
Use it for partial function evaluation in a way that is compatibile with JAX's
|
||||
transformations, e.g., ``Partial(func, *args, **kwargs)``.
|
||||
|
||||
(You need to explicitly opt-in to this behavior because we didn't want to give
|
||||
functools.partial different semantics than normal function closures.)
|
||||
"""
|
||||
|
||||
def _partial_to_iterable(partial_):
|
||||
values = partial_.args + tuple(partial_.keywords.values())
|
||||
spec = (partial_.func, len(partial_.args), tuple(partial_.keywords.keys()))
|
||||
return values, spec
|
||||
|
||||
def _iterable_to_partial(spec, values):
|
||||
func, args_count, keys = spec
|
||||
args = values[:args_count]
|
||||
keywords = dict(zip(keys, values[args_count:]))
|
||||
return Partial(func, *args, **keywords)
|
||||
|
||||
register_pytree_node(Partial, _partial_to_iterable, _iterable_to_partial)
|
||||
|
59
tests/tree_util_tests.py
Normal file
59
tests/tree_util_tests.py
Normal file
@ -0,0 +1,59 @@
|
||||
# Copyright 2019 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
from jax import tree_util
|
||||
from jax import test_util as jtu
|
||||
|
||||
|
||||
def _dummy_func(*args, **kwargs):
|
||||
return
|
||||
|
||||
|
||||
class TreeTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.parameters(
|
||||
((1, 2), ),
|
||||
([3], ),
|
||||
({'a': 1, 'b': 2}, )
|
||||
)
|
||||
def testRoundtrip(self, inputs):
|
||||
xs, tree = tree_util.tree_flatten(inputs)
|
||||
actual = tree_util.tree_unflatten(tree, xs)
|
||||
self.assertEqual(actual, inputs)
|
||||
|
||||
@parameterized.parameters(
|
||||
(tree_util.Partial(_dummy_func), ),
|
||||
(tree_util.Partial(_dummy_func, 1, 2), ),
|
||||
(tree_util.Partial(_dummy_func, x='a'), ),
|
||||
(tree_util.Partial(_dummy_func, 1, 2, 3, x=4, y=5), ),
|
||||
)
|
||||
def testRoundtrip(self, inputs):
|
||||
xs, tree = tree_util.tree_flatten(inputs)
|
||||
actual = tree_util.tree_unflatten(tree, xs)
|
||||
# functools.partial does not support equality comparisons:
|
||||
# https://stackoverflow.com/a/32786109/809705
|
||||
self.assertEqual(actual.func, inputs.func)
|
||||
self.assertEqual(actual.args, inputs.args)
|
||||
self.assertEqual(actual.keywords, inputs.keywords)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
Loading…
x
Reference in New Issue
Block a user