rocm_jax/jax/linear_util.py

213 lines
6.2 KiB
Python
Raw Normal View History

2018-11-17 18:03:33 -08:00
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
2019-03-12 15:07:52 -04:00
"""
Utilities for defining linear functions composed with transformations.
"Linear" here is meant in the sense of linear types; that is, a linear function
may be called at most once.
For example:
from jax import linear_util as lu
# A transformation that scales its argument down and its result up.
@lu.transformation
def scale_transformer(scale, x):
ans = yield (x / scale,)
2019-03-12 15:25:13 -04:00
yield ans * scale
2019-03-12 15:07:52 -04:00
def f(x):
return x + 1
g = lu.wrap_init(f) # Wraps `f` as a `WrappedFun`.
g = scale_transformer(g, 2.0) # Scale inputs/outputs by 2.0
g = scale_transformer(g, 0.7) # Scale inputs/outputs further by 0.7.
print(g.call_wrapped(3.)) # Call the transformed function.
A `WrappedFun` object represents a function `f`, together with a
sequence of nested transformations that are to be applied to the positional
arguments at call time and function return values at return time.
`WrappedFun` objects explicitly represent the set of transformations so that
they can be used as dictionary keys for memoization. `WrappedFun` objects
compare as equal only if they wrap the same function under the same
transformations and keyword arguments, and hence compute the same function.
Transformations are implemented as generators to save call stack frames.
A transformation's generator takes arguments `gen args + args`, and yields
a tuple of transformed arguments that should be passed to the wrapped
function. The result of the wrapped function is passed back to the generator
using gen.send(), and the generator yields the transformed results to pass
back to the caller.
Transformations can also return auxiliary data using the `transformation_with_aux`
decorator. For example:
@lu.transformation_with_aux
def scale_transformer_aux(scale, x):
ans = yield (x / scale,)
2019-03-12 15:25:13 -04:00
yield (ans * scale, "Auxiliary data: {}".format(x))
2019-03-12 15:07:52 -04:00
g = lu.wrap_init(f) # Wraps `f` as a `WrappedFun`.
g, aux_thunk = scale_transformer_aux(g, 2.0) # Scale inputs/outputs by 2.0
print(g.call_wrapped(3.)) # Call the transformed function.
print(aux_thunk()) # Retrieves the auxiliary data computed during evaluation.
"""
2018-11-17 18:03:33 -08:00
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
2018-11-17 18:03:33 -08:00
from .util import curry, partial, OrderedDict
2018-11-17 18:03:33 -08:00
def thunk(f):
  """Returns a zero-argument callable that evaluates `f()` at most once.

  The first invocation calls `f` and caches its result in a write-once
  `Store`; every later invocation returns the cached value without calling
  `f` again.
  """
  cache = Store()

  def f_memoized():
    if cache:
      return cache.val
    # TODO(dougalm): save/restore relevant environment state too
    cache.store(f())
    return cache.val

  return f_memoized
class StoreException(Exception):
  """Raised on misuse of a `Store`: reading it while empty or filling it twice."""


class Store(object):
  """A write-once cell holding a single value.

  A `Store` starts empty. `store` fills it exactly once and `val` reads the
  value back. Truthiness reports whether the cell has been filled.
  """

  def store(self, val):
    """Fills the store with `val`.

    Raises:
      StoreException: if the store already holds a value. This is an explicit
        raise rather than an `assert` so the check is not stripped under
        `python -O`, where a second store would otherwise silently overwrite.
    """
    if self:
      raise StoreException("Store occupied")
    self._val = val

  @property
  def val(self):
    """The stored value.

    Raises:
      StoreException: if nothing has been stored yet.
    """
    if not self:
      raise StoreException("Store empty")
    return self._val

  def __nonzero__(self):
    # The cell counts as filled once `_val` has been assigned by `store`.
    return hasattr(self, '_val')

  __bool__ = __nonzero__  # Python 3 spelling of __nonzero__.
2018-11-17 18:03:33 -08:00
@curry
def staged(f, *init_args):
  """Splits `f` into a main-result callable and a thunk for its auxiliary output.

  `f` must return a pair `(ans, aux)`. Returns `(f_partial, aux_thunk)`:
  `f_partial(*rest)` calls `f(*(init_args + rest))`, records `aux` in a
  write-once `Store`, and returns `ans`. `aux_thunk()` returns the recorded
  `aux`. Because the store is write-once, a second call of `f_partial` fails
  when it tries to store again.
  """
  aux_store = Store()

  def f_partial(*rest):
    full_args = init_args + rest
    ans, aux = f(*full_args)
    aux_store.store(aux)
    return ans

  f_partial.__name__ = f.__name__ + "_staged"
  aux_thunk = thunk(lambda: aux_store.val)
  return f_partial, aux_thunk
class WrappedFun(object):
  """Represents a function `f` to which a stack of `transforms` is to be applied.

  Arguments:
    f: the function to be transformed.
    transforms: a list of `(gen, gen_args, out_store)` tuples representing
      transformations to apply to `f`. The first element is the outermost
      transformation: it sees the caller's arguments first and produces the
      final result. `out_store` is None for transformations without auxiliary
      output.
    kwargs: keyword arguments to pass to `f`.
  """

  def __init__(self, f, transforms, kwargs):
    self.f = f
    self.transforms = transforms
    self.kwargs = kwargs

  def wrap(self, *transformation):
    """Returns a new WrappedFun with `transformation` as its outermost transform.

    `transformation` is a `(gen, gen_args, out_store)` triple; `self` is not
    modified.
    """
    return WrappedFun(self.f, [transformation] + self.transforms, self.kwargs)

  def populate_stores(self, other):
    """Copies the auxiliary values recorded in `other`'s stores into `self`'s.

    Stores are paired positionally, so `other` should have the same
    transformation stack as `self` (e.g. an equal WrappedFun that has already
    been called, as on a memoization cache hit).
    """
    for (_, _, self_store), (_, _, other_store) in zip(self.transforms,
                                                       other.transforms):
      if self_store is not None:
        self_store.store(other_store.val)

  def call_wrapped(self, *args):
    """Calls `f` with every transformation applied.

    Each transformation generator is started with `gen_args + args` and yields
    the argument tuple for the next (inner) layer. `f` is called with the
    innermost arguments plus `self.kwargs`. Its result is then sent back through
    the generators from innermost to outermost. For a transformation with an
    `out_store`, the generator's final yield must be an `(ans, aux)` pair, and
    `aux` is recorded in the store.
    """
    stack = []
    for gen, gen_args, out_store in self.transforms:
      gen = gen(*(gen_args + tuple(args)))
      args = next(gen)
      stack.append((gen, out_store))
    # Drop the loop's references to the last generator/store. Rebinding rather
    # than `del gen` also works with an empty transform stack: there the loop
    # never binds `gen`, and `del gen` would raise UnboundLocalError.
    gen = out_store = None
    ans = self.f(*args, **self.kwargs)
    del args
    while stack:
      gen, out_store = stack.pop()
      ans = gen.send(ans)
      if out_store is not None:
        ans, side = ans
        out_store.store(side)
    return ans

  def __repr__(self):
    def transform_to_str(x):
      i, (gen, args, _) = x
      return "{} : {} {}".format(i, fun_name(gen), fun_name(args))
    transformation_stack = map(transform_to_str, enumerate(self.transforms))
    return ("Wrapped function:\n" + '\n'.join(transformation_stack) +
            '\nCore: ' + fun_name(self.f) + '\n')

  def hashable_payload(self):
    """Returns the tuple that determines hashing and equality.

    The tuple holds the core function, each transform's generator and static
    args, and the sorted keyword arguments. Output stores are deliberately
    excluded.
    """
    return (self.f,
            tuple((gen, tuple(gen_args)) for gen, gen_args, _ in self.transforms),
            tuple(sorted(self.kwargs.items())))

  def __hash__(self):
    return hash(self.hashable_payload())

  def __eq__(self, other):
    # Defer to the other operand (and ultimately identity) for foreign types
    # instead of raising AttributeError on `other.hashable_payload`.
    if not isinstance(other, WrappedFun):
      return NotImplemented
    return self.hashable_payload() == other.hashable_payload()

  def __ne__(self, other):
    # Python 2 does not derive `!=` from `__eq__`; define it explicitly.
    eq = self.__eq__(other)
    return eq if eq is NotImplemented else not eq
@curry
def transformation(gen, fun, *transformation_args):
  """Applies generator `gen` as a transformation of the WrappedFun `fun`.

  Returns a new WrappedFun with `(gen, transformation_args, None)` pushed as
  the outermost transform. The `None` marks that this transformation has no
  output store.
  """
  entry = (gen, transformation_args, None)
  return fun.wrap(*entry)
@curry
def transformation_with_aux(gen, fun, *transformation_args):
  """Like `transformation`, but `gen` also produces auxiliary output.

  Returns `(wrapped, aux_thunk)`. The final yield of `gen` must be an
  `(ans, aux)` pair; `aux` is recorded in a fresh `Store` when the wrapped
  function is called, and `aux_thunk()` reads it back.
  """
  out_store = Store()
  wrapped = fun.wrap(gen, transformation_args, out_store)
  return wrapped, (lambda: out_store.val)
def fun_name(f):
  """Returns a printable name for `f`: `f.__name__`, else `str(f)`.

  Best-effort helper for debug output. Ordinary errors while reading
  `__name__` fall back to `str(f)`. Unlike the previous bare `except:`, this
  does not swallow KeyboardInterrupt/SystemExit.
  """
  try:
    return f.__name__
  except Exception:
    # e.g. functools.partial objects or tuples of static args have no __name__.
    return str(f)
def wrap_init(f, kwargs=None):
  """Wraps function `f` as a `WrappedFun`, suitable for transformation.

  Args:
    f: the function to wrap.
    kwargs: optional dict of keyword arguments passed to `f` on every call.
      Defaults to a fresh empty dict.

  Returns:
    A `WrappedFun` with an empty transformation stack.
  """
  if kwargs is None:
    # A fresh dict per call: a shared `{}` default object would be aliased by
    # every WrappedFun created without explicit kwargs.
    kwargs = {}
  return WrappedFun(f, [], kwargs)
2019-01-15 10:32:58 -05:00
def memoize(call, max_size=4096):
  """Memoizes `call(f, *args)` on the key `(f, args)` with a bounded LRU cache.

  Args:
    call: function taking a WrappedFun `f` followed by hashable `args`.
    max_size: maximum number of cached entries. Values below 1 behave like 1,
      because the most recent result is always cached.

  Returns:
    `memoized_fun(f, *args)`. On a cache hit it returns the cached result, and
    copies the auxiliary values recorded by the WrappedFun that populated the
    entry into `f`'s stores (via `populate_stores`). This makes `f`'s aux
    thunks readable as if `f` itself had been called.
  """
  cache = OrderedDict()

  def memoized_fun(f, *args):
    key = (f, args)
    cached = cache.get(key)  # single lookup; entries are never None
    if cached is not None:
      ans, f_prev = cached
      cache.move_to_end(key)  # mark as most recently used
      f.populate_stores(f_prev)
      return ans

    ans = call(f, *args)
    # Evict the least recently used entry before inserting, so the cache never
    # holds more than `max_size` entries (`>` previously allowed max_size + 1).
    # Evicting only after `call` succeeds avoids dropping an entry on failure.
    if cache and len(cache) >= max_size:
      cache.popitem(last=False)
    cache[key] = (ans, f)
    return ans

  return memoized_fun