Delete jax.experimental.callback

PiperOrigin-RevId: 501760507
This commit is contained in:
Sharad Vikram 2023-01-12 22:57:49 -08:00 committed by jax authors
parent e21c29476d
commit c9a57e1b44
5 changed files with 3 additions and 589 deletions

View File

@ -8,6 +8,9 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.2
* Breaking changes
* Deleted `jax.experimental.callback`
## jaxlib 0.4.2
## jax 0.4.1 (Dec 13, 2022)

View File

@ -187,13 +187,6 @@ pytype_library(
deps = [":jax"],
)
pytype_library(
name = "callback",
srcs = ["experimental/callback.py"],
visibility = ["//visibility:public"],
deps = [":jax"],
)
# TODO(apaszke): Remove this target
pytype_library(
name = "maps",

View File

@ -1,289 +0,0 @@
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import itertools as it
from typing import Any, Callable, Dict, Sequence, Union
import jax.numpy as jnp
from jax import core
from jax.core import Trace, Tracer, jaxpr_as_fun
from jax import lax
from jax import custom_derivatives as cd
from jax.interpreters import partial_eval as pe
from jax._src import linear_util as lu
from jax._src.util import safe_map, wraps, split_list
from jax._src.lax import control_flow as lcf
import inspect
from jax._src.api_util import flatten_fun_nokwargs
from jax.tree_util import tree_flatten, tree_unflatten, tree_structure, tree_map
map = safe_map
### Public
def callback_transform(
fun: Callable, callback: Callable, strip_calls: bool=False) -> Callable:
_check_callable(fun)
@wraps(fun)
def wrapped_fun(*args):
args_flat, in_tree = tree_flatten(args)
f = lu.wrap_init(fun)
flat_fun, out_tree = flatten_fun_nokwargs(f, in_tree)
out_flat = callback_fun(flat_fun, args_flat, callback, strip_calls)
return tree_unflatten(out_tree(), out_flat)
return wrapped_fun
### Example Transform
def find_by_value(fun: Callable, queries) -> Callable:
def find_callback(
prim: core.Primitive,
vals: Sequence[core.Tracer],
params: Dict[str, Any]) -> Union[core.Tracer, Sequence[core.Tracer]]:
vals = prim.bind(*vals, **params)
_contains_query(vals, queries)
return vals
return callback_transform(fun, find_callback, True)
def rewrite(fun: Callable, rules) -> Callable:
assert isinstance(rules, dict)
def rewrite_callback(
prim: core.Primitive,
vals: Sequence[core.Tracer],
params: Dict[str, Any]) -> Union[core.Tracer, Sequence[core.Tracer]]:
if prim in rules:
return rules[prim](*vals, **params)
return prim.bind(*vals, **params)
return callback_transform(fun, rewrite_callback)
class FoundValue(Exception):
pass
def _contains_query(vals, query):
if isinstance(query, tuple):
return map(partial(_contains_query, vals), query)
if jnp.isnan(query):
if jnp.any(jnp.isnan(vals)):
raise FoundValue('NaN')
elif jnp.isinf(query):
if jnp.any(jnp.isinf(vals)):
raise FoundValue('Found Inf')
elif jnp.isscalar(query):
if jnp.any(vals == query):
raise FoundValue(str(query))
else:
raise ValueError(f'Malformed Query: {query}')
### Helper Functions
def callback_fun(fun : lu.WrappedFun, in_vals, callback, strip_calls):
fun = callback_subtrace(fun)
fun = _callback_fun(fun, callback, strip_calls)
return fun.call_wrapped(*in_vals)
@lu.transformation
def callback_subtrace(main, *in_vals, **params):
trace = main.with_cur_sublevel()
in_tracers = [CallbackTracer(trace, val) for val in in_vals]
outs = yield in_tracers, params
out_tracers = map(trace.full_raise, outs)
out_vals = [t.val for t in out_tracers]
yield out_vals
@lu.transformation
def _callback_fun(callback, strip_calls, *in_vals, **params):
with core.new_main(CallbackTrace, callback=callback,
strip_calls=strip_calls) as main:
out_vals = yield (main,) + in_vals, params
del main
yield out_vals
def callback_jaxpr(closed_jaxpr, callback, strip_calls):
fun = lu.wrap_init(jaxpr_as_fun(closed_jaxpr))
fun = callback_subtrace(fun)
fun = _callback_fun(fun, callback, strip_calls)
avals_in = closed_jaxpr.in_avals
jaxpr_out, consts = cd._initial_style_jaxpr(fun, avals_in)
return core.ClosedJaxpr(jaxpr_out, consts)
def _check_callable(fun):
if not callable(fun):
raise TypeError(f"Expected a callable value, got {fun}")
if inspect.isgeneratorfunction(fun):
raise TypeError(f"Expected a function, got a generator function: {fun}")
### Tracer
class CallbackTracer(Tracer):
__slots__ = ['val']
def __init__(self, trace, val):
self._trace = trace
self.val = val
@property
def aval(self):
return core.get_aval(self.val)
def full_lower(self):
return self
class CallbackTrace(Trace):
def __init__(self, *args, callback, strip_calls):
super().__init__(*args)
self.callback = callback
self.strip_calls = strip_calls
def pure(self, val):
return CallbackTracer(self, val)
def lift(self, val):
return CallbackTracer(self, val)
def sublift(self, val):
return CallbackTracer(self, val.val)
def process_primitive(self, primitive, tracers, params):
if primitive in custom_callback_rules:
return custom_callback_rules[primitive](self, *tracers, **params)
vals_in = [t.val for t in tracers]
vals_out = self.callback(primitive, vals_in, params) # type: ignore
if primitive.multiple_results:
return [CallbackTracer(self, val) for val in vals_out]
return CallbackTracer(self, vals_out)
def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
if self.strip_calls: # type: ignore
return f.call_wrapped(*tracers)
vals_in = [t.val for t in tracers]
f = callback_subtrace(f, self.main)
vals_out = call_primitive.bind(f, *vals_in, **params)
return [CallbackTracer(self, val) for val in vals_out]
def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
vals_in = [t.val for t in tracers]
fun = callback_subtrace(fun, self.main)
jvp = callback_subtrace(jvp, self.main)
out = primitive.bind(fun, jvp, *vals_in)
return safe_map(self.pure, out)
def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers,
out_trees):
vals_in = [t.val for t in tracers]
fun = callback_subtrace(fun, self.main)
fwd = callback_subtrace(fwd, self.main)
bwd = callback_subtrace(bwd, self.main)
out = primitive.bind(fun, fwd, bwd, *vals_in, out_trees=out_trees)
return safe_map(self.pure, out)
custom_callback_rules: Dict[Any, Any] = {}
def _scan_callback_rule(trace, *tracers, reverse, length, num_consts, num_carry,
jaxpr, linear, unroll):
const_tracers, carry_tracers, xs_tracers = split_list(tracers, [num_consts, num_carry])
carry_avals, xs_avals = tree_map(lambda x: x.aval, (carry_tracers, xs_tracers))
const_vals, carry_vals, xs_vals = tree_map(lambda x: x.val, (const_tracers, carry_tracers, xs_tracers))
x_tracers = [t[0] for t in xs_tracers]
x_avals = [t.aval for t in x_tracers]
body_fun = jaxpr_as_fun(jaxpr)
def new_body(*vals):
out = body_fun(*vals)
out_carry, y = split_list(out, [num_carry])
return out_carry, y
new_body = callback_transform(new_body, trace.callback,
strip_calls=trace.strip_calls) # type: ignore
in_tree = tree_structure(carry_avals + xs_avals)
new_jaxpr, new_consts, _ = lcf._initial_style_jaxpr(
new_body, in_tree, tuple(carry_avals + x_avals))
vals = tuple(it.chain(new_consts, carry_vals, xs_vals))
out_vals = lax.scan_p.bind(*vals, reverse=reverse, length=length,
num_consts=len(new_consts), num_carry=num_carry,
jaxpr=new_jaxpr, linear=linear, unroll=unroll)
return safe_map(trace.pure, out_vals)
custom_callback_rules[lax.scan_p] = _scan_callback_rule
def _while_callback_rule(trace, *tracers, cond_jaxpr, body_jaxpr,
cond_nconsts, body_nconsts):
cond_const_tracers, body_const_tracers, init_tracers = split_list(
tracers, [cond_nconsts, body_nconsts])
init_avals = safe_map(lambda x: x.aval, init_tracers)
cond_const_vals, body_const_vals, init_vals = tree_map(
lambda x: x.val, (cond_const_tracers, body_const_tracers, init_tracers))
body_fun = jaxpr_as_fun(body_jaxpr)
cond_fun = jaxpr_as_fun(cond_jaxpr)
def cond(*carry):
return cond_fun(*it.chain(cond_const_vals, carry))
def body(*carry):
return body_fun(*it.chain(body_const_vals, carry))
new_cond = callback_transform(cond, trace.callback, strip_calls=trace.strip_calls) # type: ignore
new_body = callback_transform(body, trace.callback, strip_calls=trace.strip_calls) # type: ignore
in_tree = tree_structure(init_avals)
new_cond_jaxpr, new_cond_consts, _ = lcf._initial_style_jaxpr(new_cond, in_tree, tuple(init_avals))
new_body_jaxpr, new_body_consts, _ = lcf._initial_style_jaxpr(new_body, in_tree, tuple(init_avals))
out = lcf.while_p.bind(
*it.chain(new_cond_consts, new_body_consts, init_vals),
cond_nconsts=len(new_cond_consts),
body_nconsts=len(new_body_consts),
cond_jaxpr=new_cond_jaxpr,
body_jaxpr=new_body_jaxpr)
return safe_map(trace.pure, out)
custom_callback_rules[lax.while_p] = _while_callback_rule
def _custom_derivative_call_jaxpr_callback_rule(primitive, trace, *tracers,
fun_jaxpr, num_consts, **params):
main = trace.main
vals = [t.val for t in tracers]
new_closed_jaxpr = callback_jaxpr(fun_jaxpr, trace.callback, strip_calls=trace.strip_calls)
if primitive == cd.custom_vjp_call_jaxpr_p:
thunk_name = 'fwd_jaxpr_thunk'
params['bwd'] = callback_subtrace(params['bwd'], main)
else:
raise NotImplementedError(primitive)
thunk = params.pop(thunk_name)
@pe._memoize
def new_thunk():
thunk_jaxpr = core.ClosedJaxpr(*thunk())
closed_jaxpr = callback_jaxpr(thunk_jaxpr, trace.callback, trace.strip_calls)
return closed_jaxpr.jaxpr, closed_jaxpr.literals
params[thunk_name] = new_thunk
new_fun_jaxpr, new_consts = new_closed_jaxpr.jaxpr, new_closed_jaxpr.literals
closed_fun_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(new_fun_jaxpr), ())
new_num_consts = len(new_consts) + num_consts
out = primitive.bind(*it.chain(new_consts, vals), fun_jaxpr=closed_fun_jaxpr,
num_consts=new_num_consts, **params)
return safe_map(trace.pure, out)
custom_callback_rules[cd.custom_vjp_call_jaxpr_p] = partial(
_custom_derivative_call_jaxpr_callback_rule, cd.custom_vjp_call_jaxpr_p)

View File

@ -66,12 +66,6 @@ jax_test(
},
)
jax_test(
name = "callback_test",
srcs = ["callback_test.py"],
deps = ["//jax:callback"],
)
jax_test(
name = "core_test",
srcs = ["core_test.py"],

View File

@ -1,287 +0,0 @@
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
import jax
from jax._src import test_util as jtu
from jax.experimental.callback import find_by_value, rewrite, FoundValue
import jax.numpy as jnp
from jax import lax
from jax import jit
from jax import grad
from jax.config import config
config.parse_flags_with_absl()
class CallbackTest(jtu.JaxTestCase):
@jtu.sample_product(value=[jnp.inf, jnp.nan])
def testFindByValueFound(self, value):
def f(x):
y = x ** 2
z = 1 - y
r = 1 / z
return r * 0
with self.assertRaises(FoundValue):
find_by_value(f, value)(jnp.array([1.0, 2.0, 3.0]))
@jtu.sample_product(value=[jnp.inf, jnp.nan])
def testFindByValueFoundJIT(self, value):
def f(x):
@jit
def g(x):
y = x ** 2
z = 1 - y
r = 1 / z
return r * 0
return g(x)
with self.assertRaises(FoundValue):
find_by_value(f, value)(jnp.array([1.0, 2.0, 3.0]))
@jtu.sample_product(value=[jnp.inf, jnp.nan])
def testFindByValueNotFound(self, value):
def f(x):
y = x ** 2
z = 1 - y
return z
find_by_value(f, value)(jnp.array([1.0, 2.0, 3.0]))
def testRewrite(self):
def f(x):
return x * 2
x = jnp.array([2.0, 4.0])
self.assertAllClose(f(x), jnp.array([4.0, 8.0]))
self.assertAllClose(
rewrite(f, {lax.mul_p: lambda x, y: x + y})(x),
jnp.array([4.0, 6.0]))
def testRewriteJIT(self):
def f(x):
@jit
def g(x):
return x * 2
return g(x)
x = jnp.array([2.0, 4.0])
self.assertAllClose(f(x), jnp.array([4.0, 8.0]))
self.assertAllClose(
rewrite(f, {lax.mul_p: lambda x, y: x + y})(x),
jnp.array([4.0, 6.0]))
def testRewriteWithCustomGradients(self):
def f(x):
return jax.nn.relu(x)
x = jnp.array([2.0, 4.0])
self.assertAllClose(f(x), jnp.array([2.0, 4.0]))
self.assertAllClose(
rewrite(f, {})(x),
jnp.array([2.0, 4.0]))
def testRewriteThroughScan(self):
def f(xs):
def body(carry, x):
carry = carry * 2.
return carry, x - 2.
return lax.scan(body, 1., xs)
xs = jnp.arange(4.)
carry, ys = f(xs)
self.assertAllClose(carry, 16.)
self.assertAllClose(ys, jnp.arange(4.) - 2.)
rewrites = {
lax.mul_p: lambda x, y: x + y,
lax.sub_p: lambda x, y: x / y
}
carry, ys = rewrite(f, rewrites)(xs)
self.assertAllClose(carry, 1. + 8.)
self.assertAllClose(ys, jnp.arange(4.) / 2.)
def testRewriteThroughWhile(self):
def f(x):
def cond(x):
return x < 5
def body(x):
return x + 1
return lax.while_loop(cond, body, x)
x = 0
self.assertAllClose(f(x), 5)
rewrites = {
lax.add_p: lambda x, y: x + y + 100,
}
self.assertAllClose(rewrite(f, rewrites)(x), 101)
rewrites = {
lax.lt_p: lambda x, y: x < y + 5
}
self.assertAllClose(rewrite(f, rewrites)(x), 10)
def testRewriteThroughForLoop(self):
def f(x):
def body(i, x):
return x * i
return lax.fori_loop(1, 5, body, x)
x = 1
self.assertAllClose(f(x), 24)
rewrites = {
lax.mul_p: lambda x, y: x + y
}
self.assertAllClose(rewrite(f, rewrites)(x), 11)
def testRewriteThroughCustomVJP(self):
@jax.custom_gradient
def f(x):
return x * 2, lambda g: g + x
x = 2.
self.assertAllClose(f(x), 4.)
self.assertAllClose(grad(f)(x), 3.)
rewrites = {
lax.mul_p: lambda x, y: x / y
}
g = rewrite(f, rewrites)
self.assertAllClose(g(x), 1.)
self.assertAllClose(grad(g)(x), 3.)
rewrites = {
lax.add_p: lambda x, y: x - y
}
g = rewrite(f, rewrites)
self.assertAllClose(g(x), 4.)
self.assertAllClose(grad(g)(x), -1.)
def testRewriteThroughCustomVJPInScan(self):
@jax.custom_gradient
def foo(x):
return x * 2, lambda g: g + x
def f(x):
out, _ = lax.scan(lambda c, _: (foo(c), None), x, None, length=1)
return out
x = 2.
self.assertAllClose(f(x), 4.)
self.assertAllClose(grad(f)(x), 3.)
rewrites = {
lax.mul_p: lambda x, y: x / y
}
g = rewrite(f, rewrites)
self.assertAllClose(g(x), 1.)
self.assertAllClose(grad(g)(x), 3.)
rewrites = {
lax.add_p: lambda x, y: x * y
}
g = rewrite(f, rewrites)
self.assertAllClose(g(x), 4.)
self.assertAllClose(grad(g)(x), 2.)
def testRewriteThroughCustomJVP(self):
@jax.custom_jvp
def f(x):
return x + 2
@f.defjvp
def f_jvp(primals, tangents):
x, = primals
d, = tangents
return f(x), x * d
x = 2.
self.assertAllClose(f(x), 4.)
f_primal, jvp = jax.jvp(f, (x,), (1.,))
self.assertAllClose(f_primal, 4.)
self.assertAllClose(jvp, 2.)
self.assertAllClose(grad(f)(x), 2.)
rewrites = {
lax.add_p: lambda x, y: x - y
}
g = rewrite(f, rewrites)
self.assertAllClose(g(x), 0.)
g_primal, jvp = jax.jvp(g, (x,), (1.,))
self.assertAllClose(g_primal, 0.)
self.assertAllClose(jvp, 2.)
self.assertAllClose(grad(g)(x), 2.)
def testRewriteThroughCustomJVPInScan(self):
@jax.custom_jvp
def foo(x):
return x + 2
@foo.defjvp
def foo_jvp(primals, tangents):
x, = primals
d, = tangents
return f(x), x * d
def f(x):
out, _ = lax.scan(lambda c, _: (foo(c), None), x, None, length=1)
return out
x = 2.
self.assertAllClose(f(x), 4.)
f_primal, jvp = jax.jvp(f, (x,), (1.,))
self.assertAllClose(f_primal, 4.)
self.assertAllClose(jvp, 2.)
self.assertAllClose(grad(f)(x), 2.)
rewrites = {
lax.add_p: lambda x, y: x - y
}
g = rewrite(f, rewrites)
self.assertAllClose(g(x), 0.)
g_primal, jvp = jax.jvp(g, (x,), (1.,))
self.assertAllClose(g_primal, 0.)
self.assertAllClose(jvp, 2.)
self.assertAllClose(grad(g)(x), 2.)
rewrites = {
lax.mul_p: lambda x, y: x + y
}
g = rewrite(f, rewrites)
self.assertAllClose(g(x), 4.)
g_primal, jvp = jax.jvp(g, (x,), (1.,))
self.assertAllClose(g_primal, 4.)
self.assertAllClose(jvp, 3.)
self.assertAllClose(grad(g)(x), 1.)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())