2018-11-17 18:03:33 -08:00
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
from __future__ import absolute_import
|
2018-11-21 13:27:26 -08:00
|
|
|
from __future__ import division
|
|
|
|
from __future__ import print_function
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
from .core import pack
|
2019-01-28 08:37:49 -08:00
|
|
|
from .tree_util import (build_tree, process_pytree, tree_flatten,
|
|
|
|
tree_unflatten, leaf)
|
2018-11-17 18:03:33 -08:00
|
|
|
from .linear_util import transformation_with_aux
|
2019-01-03 16:14:30 -08:00
|
|
|
from .util import safe_map, unzip2, partial, curry
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2018-11-21 13:20:44 -08:00
|
|
|
# Rebind `map` for the rest of this module to `safe_map` (imported from
# .util above), so every `map(...)` call below uses that helper instead of
# the builtin.
# NOTE(review): presumably safe_map checks argument lengths and returns a
# list -- confirm in .util before relying on either property.
map = safe_map
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-01-03 16:14:30 -08:00
|
|
|
|
|
|
|
@curry
def wraps(wrapped, fun, namestr="{fun}", docstr="{doc}", **kwargs):
  """Copy identifying metadata from `wrapped` onto `fun` and return `fun`.

  Sets ``fun.__name__``, ``fun.__module__``, ``fun.__doc__`` and
  ``fun.__wrapped__``. The name and docstring are produced by formatting
  ``namestr`` and ``docstr``; the remaining attributes are copied as-is.

  Args:
    wrapped: object whose name, module and docstring are read (through
      ``get_name``, ``get_module`` and ``get_doc``).
    fun: object that receives the metadata; it is mutated in place.
    namestr: format template for the new ``__name__``; it may reference
      ``{fun}`` (the name of ``wrapped``).
    docstr: format template for the new ``__doc__``; it may reference
      ``{fun}``, ``{doc}`` (the docstring of ``wrapped``) and any key given
      in ``kwargs``.
    **kwargs: extra fields made available to ``docstr``.

  Returns:
    ``fun`` itself, whether or not every attribute could be set.
  """
  # NOTE(review): @curry presumably allows `wraps(wrapped)(fun)` decorator
  # usage -- confirm in .util.
  try:
    wrapped_name = get_name(wrapped)
    fun.__name__ = namestr.format(fun=wrapped_name)
    fun.__module__ = get_module(wrapped)
    fun.__doc__ = docstr.format(fun=wrapped_name, doc=get_doc(wrapped),
                                **kwargs)
    fun.__wrapped__ = wrapped
  except Exception:
    # Best effort: some objects reject attribute assignment and templates may
    # reference missing keys. Unlike a `return` inside `finally`, this lets
    # KeyboardInterrupt / SystemExit propagate instead of being discarded.
    pass
  return fun
|
|
|
|
|
|
|
|
def get_name(fun):
  """Return ``fun.__name__``, or "<unnamed function>" if it has none."""
  try:
    return fun.__name__
  except AttributeError:
    return "<unnamed function>"
|
|
|
|
def get_module(fun):
  """Return ``fun.__module__``, or "<unknown module>" if it has none."""
  try:
    return fun.__module__
  except AttributeError:
    return "<unknown module>"
|
|
|
|
def get_doc(fun):
  """Return the docstring of `fun`, or "" when there is none.

  Callables defined without a docstring still carry ``__doc__ = None``, so a
  plain ``getattr`` default is never used for them. Coercing that ``None`` to
  "" keeps the literal text "None" out of docstrings built by formatting the
  result into a template.
  """
  return getattr(fun, "__doc__", None) or ""
|
2019-01-03 16:14:30 -08:00
|
|
|
|
2019-04-10 22:09:14 -07:00
|
|
|
@transformation_with_aux
def pytree_fun_to_jaxtupletree_fun(args_trees, *args):
  """Generator that rebuilds positional args and converts the result.

  Each element of ``args`` is rebuilt with ``build_tree`` using the matching
  entry of ``args_trees``. The rebuilt arguments are yielded together with an
  empty kwargs dict; the value sent back in is passed through
  ``pytree_to_jaxtupletree`` and yielded.
  """
  # NOTE(review): how the two yields are consumed is defined by
  # transformation_with_aux in .linear_util -- confirm there.
  rebuilt_args = map(build_tree, args_trees, args)
  no_kwargs = {}
  result = yield rebuilt_args, no_kwargs
  yield pytree_to_jaxtupletree(result)
|
2019-01-03 16:14:30 -08:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
@transformation_with_aux
def pytree_fun_to_jaxtupletree_fun2(kwargs_tree, args_trees, kwargs, *args):
  """Generator that rebuilds positional and keyword args, then converts.

  Positional ``args`` are rebuilt element-wise with ``build_tree`` from
  ``args_trees``; ``kwargs`` is rebuilt as a whole from ``kwargs_tree``. Both
  are yielded as a pair; the value sent back in is passed through
  ``pytree_to_jaxtupletree`` and yielded.
  """
  # NOTE(review): how the two yields are consumed is defined by
  # transformation_with_aux in .linear_util -- confirm there.
  rebuilt_args = map(build_tree, args_trees, args)
  rebuilt_kwargs = build_tree(kwargs_tree, kwargs)
  result = yield rebuilt_args, rebuilt_kwargs
  yield pytree_to_jaxtupletree(result)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-01-03 16:14:30 -08:00
|
|
|
def apply_jaxtree_fun(fun, io_tree, *py_args):
  """Call `fun` on converted arguments after checking their tree structure.

  Args:
    fun: callable invoked with the converted positional arguments.
    io_tree: pair ``(in_trees_expected, out_tree)``.
    *py_args: arguments, each converted with ``pytree_to_jaxtupletree``.

  Returns:
    ``build_tree(out_tree, fun(*converted_args))``.

  Raises:
    TypeError: if the tree of an argument differs from the expected tree at
      the same position. Pairing uses ``zip``, so positions beyond the
      shorter of the two sequences are not compared.
  """
  expected_in_trees, out_tree = io_tree
  converted = map(pytree_to_jaxtupletree, py_args)
  flat_args, actual_in_trees = unzip2(converted)
  for actual, wanted in zip(actual_in_trees, expected_in_trees):
    if actual != wanted:
      raise TypeError("Expected {}, got {}".format(wanted, actual))
  return build_tree(out_tree, fun(*flat_args))
|
|
|
|
|
2019-01-03 16:14:30 -08:00
|
|
|
# `process_pytree` with `pack` pre-bound as its first argument. Within this
# module the result of calling it is treated as a pair: apply_jaxtree_fun
# unzips the results into (args, in_trees).
# NOTE(review): the exact contract of process_pytree lives in .tree_util --
# confirm there.
pytree_to_jaxtupletree = partial(process_pytree, pack)
|
2019-01-28 08:37:49 -08:00
|
|
|
|
|
|
|
|
|
|
|
@transformation_with_aux
def pytree_fun_to_flatjaxtuple_fun(in_trees, *args):
  """Generator that unflattens positional args and flattens the result.

  Each element of ``args`` is rebuilt with ``tree_unflatten`` using the
  matching entry of ``in_trees``. The rebuilt arguments are yielded together
  with an empty kwargs dict; the value sent back in is passed through
  ``pytree_to_flatjaxtuple`` and yielded.
  """
  # NOTE(review): how the two yields are consumed is defined by
  # transformation_with_aux in .linear_util -- confirm there.
  rebuilt_args = map(tree_unflatten, in_trees, args)
  no_kwargs = {}
  result = yield rebuilt_args, no_kwargs
  yield pytree_to_flatjaxtuple(result)
|
2019-01-28 09:00:02 -08:00
|
|
|
|
add partial value lattice join, cond support
This change allows one side of a cond to have a different const-ness
from the other side, from the point-of-view of partial evaluation. In
other words, this now works as expected:
```python
lax.cond(x < 0, x, lambda x: 0., x, lambda x: x) # relu
```
The partial evaluation logic works with tuples, so this works too:
```python
lax.cond(x < 0,
x, lambda x: (x, x, 1, 1, 1),
x, lambda x: (x, 1, x, 1, 2))
```
in that true_fun is resolved to something like `lambda x: (x, x, 1, *, 1)`
and false_fun is resolved to something like `lambda x: (x, 1, x, *, 2)`,
where `*` means unit and corresponds to a known constant that isn't
staged into the computation.
For forward-mode autodiff support, we'll need to add yet another lattice
join on the lattice of symbolic-zero-or-not.
2019-03-02 17:37:38 -08:00
|
|
|
def pytree_to_flatjaxtuple(pytree):
  """Flatten `pytree` and pack its leaves.

  Returns:
    A pair ``(packed, out_tree)``: the result of calling ``pack`` on the
    flattened values from ``tree_flatten``, and the second value that
    ``tree_flatten`` returned alongside them.
  """
  leaves, tree_description = tree_flatten(pytree)
  packed = pack(leaves)
  return packed, tree_description
|