From c9a57e1b44875ab6b09c9ad92ba955f554def9bc Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Thu, 12 Jan 2023 22:57:49 -0800 Subject: [PATCH] Delete `jax.experimental.callback` PiperOrigin-RevId: 501760507 --- CHANGELOG.md | 3 + jax/BUILD | 7 - jax/experimental/callback.py | 289 ----------------------------------- tests/BUILD | 6 - tests/callback_test.py | 287 ---------------------------------- 5 files changed, 3 insertions(+), 589 deletions(-) delete mode 100644 jax/experimental/callback.py delete mode 100644 tests/callback_test.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 7d1877b28..bf67a1ae2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ Remember to align the itemized text with the first line of an item within a list ## jax 0.4.2 +* Breaking changes + * Deleted `jax.experimental.callback` + ## jaxlib 0.4.2 ## jax 0.4.1 (Dec 13, 2022) diff --git a/jax/BUILD b/jax/BUILD index 0734f9c11..13c27fe7b 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -187,13 +187,6 @@ pytype_library( deps = [":jax"], ) -pytype_library( - name = "callback", - srcs = ["experimental/callback.py"], - visibility = ["//visibility:public"], - deps = [":jax"], -) - # TODO(apaszke): Remove this target pytype_library( name = "maps", diff --git a/jax/experimental/callback.py b/jax/experimental/callback.py deleted file mode 100644 index 78945e001..000000000 --- a/jax/experimental/callback.py +++ /dev/null @@ -1,289 +0,0 @@ -# Copyright 2020 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from functools import partial -import itertools as it - -from typing import Any, Callable, Dict, Sequence, Union - -import jax.numpy as jnp - -from jax import core -from jax.core import Trace, Tracer, jaxpr_as_fun -from jax import lax -from jax import custom_derivatives as cd -from jax.interpreters import partial_eval as pe -from jax._src import linear_util as lu -from jax._src.util import safe_map, wraps, split_list -from jax._src.lax import control_flow as lcf - -import inspect -from jax._src.api_util import flatten_fun_nokwargs -from jax.tree_util import tree_flatten, tree_unflatten, tree_structure, tree_map - -map = safe_map - -### Public - -def callback_transform( - fun: Callable, callback: Callable, strip_calls: bool=False) -> Callable: - _check_callable(fun) - - @wraps(fun) - def wrapped_fun(*args): - args_flat, in_tree = tree_flatten(args) - f = lu.wrap_init(fun) - flat_fun, out_tree = flatten_fun_nokwargs(f, in_tree) - out_flat = callback_fun(flat_fun, args_flat, callback, strip_calls) - return tree_unflatten(out_tree(), out_flat) - - return wrapped_fun - -### Example Transform - -def find_by_value(fun: Callable, queries) -> Callable: - def find_callback( - prim: core.Primitive, - vals: Sequence[core.Tracer], - params: Dict[str, Any]) -> Union[core.Tracer, Sequence[core.Tracer]]: - vals = prim.bind(*vals, **params) - _contains_query(vals, queries) - return vals - return callback_transform(fun, find_callback, True) - -def rewrite(fun: Callable, rules) -> Callable: - assert isinstance(rules, dict) - def rewrite_callback( - prim: core.Primitive, - vals: Sequence[core.Tracer], - params: Dict[str, Any]) -> Union[core.Tracer, Sequence[core.Tracer]]: - if prim in rules: - return rules[prim](*vals, **params) - return prim.bind(*vals, **params) - return callback_transform(fun, rewrite_callback) - -class FoundValue(Exception): - pass - -def _contains_query(vals, query): - if isinstance(query, tuple): - return map(partial(_contains_query, vals), query) - - if jnp.isnan(query): - if jnp.any(jnp.isnan(vals)): - raise FoundValue('NaN') - elif jnp.isinf(query): - if jnp.any(jnp.isinf(vals)): - raise FoundValue('Found Inf') - elif jnp.isscalar(query): - if jnp.any(vals == query): - raise FoundValue(str(query)) - else: - raise ValueError(f'Malformed Query: {query}') - -### Helper Functions - -def callback_fun(fun : lu.WrappedFun, in_vals, callback, strip_calls): - fun = callback_subtrace(fun) - fun = _callback_fun(fun, callback, strip_calls) - return fun.call_wrapped(*in_vals) - -@lu.transformation -def callback_subtrace(main, *in_vals, **params): - trace = main.with_cur_sublevel() - in_tracers = [CallbackTracer(trace, val) for val in in_vals] - outs = yield in_tracers, params - out_tracers = map(trace.full_raise, outs) - out_vals = [t.val for t in out_tracers] - yield out_vals - -@lu.transformation -def _callback_fun(callback, strip_calls, *in_vals, **params): - with core.new_main(CallbackTrace, callback=callback, - strip_calls=strip_calls) as main: - out_vals = yield (main,) + in_vals, params - del main - yield out_vals - -def callback_jaxpr(closed_jaxpr, callback, strip_calls): - fun = lu.wrap_init(jaxpr_as_fun(closed_jaxpr)) - fun = callback_subtrace(fun) - fun = _callback_fun(fun, callback, strip_calls) - avals_in = closed_jaxpr.in_avals - jaxpr_out, consts = cd._initial_style_jaxpr(fun, avals_in) - return core.ClosedJaxpr(jaxpr_out, consts) - -def _check_callable(fun): - if not callable(fun): - raise TypeError(f"Expected a callable value, got {fun}") - if inspect.isgeneratorfunction(fun): - raise TypeError(f"Expected a function, got a generator function: {fun}") - -### Tracer - -class CallbackTracer(Tracer): - __slots__ = ['val'] - - def __init__(self, trace, val): - self._trace = trace - self.val = val - - @property - def aval(self): - return core.get_aval(self.val) - - def full_lower(self): - return self - -class CallbackTrace(Trace): - def __init__(self, *args, callback, strip_calls): - super().__init__(*args) - self.callback = callback - self.strip_calls = strip_calls - - def pure(self, val): - return CallbackTracer(self, val) - - def lift(self, val): - return CallbackTracer(self, val) - - def sublift(self, val): - return CallbackTracer(self, val.val) - - def process_primitive(self, primitive, tracers, params): - if primitive in custom_callback_rules: - return custom_callback_rules[primitive](self, *tracers, **params) - vals_in = [t.val for t in tracers] - vals_out = self.callback(primitive, vals_in, params) # type: ignore - if primitive.multiple_results: - return [CallbackTracer(self, val) for val in vals_out] - return CallbackTracer(self, vals_out) - - def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params): - if self.strip_calls: # type: ignore - return f.call_wrapped(*tracers) - vals_in = [t.val for t in tracers] - f = callback_subtrace(f, self.main) - vals_out = call_primitive.bind(f, *vals_in, **params) - return [CallbackTracer(self, val) for val in vals_out] - - def process_custom_jvp_call(self, primitive, fun, jvp, tracers): - vals_in = [t.val for t in tracers] - fun = callback_subtrace(fun, self.main) - jvp = callback_subtrace(jvp, self.main) - out = primitive.bind(fun, jvp, *vals_in) - return safe_map(self.pure, out) - - def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, - out_trees): - vals_in = [t.val for t in tracers] - fun = callback_subtrace(fun, self.main) - fwd = callback_subtrace(fwd, self.main) - bwd = callback_subtrace(bwd, self.main) - out = primitive.bind(fun, fwd, bwd, *vals_in, out_trees=out_trees) - return safe_map(self.pure, out) - -custom_callback_rules: Dict[Any, Any] = {} - -def _scan_callback_rule(trace, *tracers, reverse, length, num_consts, num_carry, - jaxpr, linear, unroll): - const_tracers, carry_tracers, xs_tracers = split_list(tracers, [num_consts, num_carry]) - carry_avals, xs_avals = tree_map(lambda x: x.aval, (carry_tracers, xs_tracers)) - const_vals, carry_vals, xs_vals = tree_map(lambda x: x.val, (const_tracers, carry_tracers, xs_tracers)) - - x_tracers = [t[0] for t in xs_tracers] - x_avals = [t.aval for t in x_tracers] - - body_fun = jaxpr_as_fun(jaxpr) - - def new_body(*vals): - out = body_fun(*vals) - out_carry, y = split_list(out, [num_carry]) - return out_carry, y - new_body = callback_transform(new_body, trace.callback, - strip_calls=trace.strip_calls) # type: ignore - in_tree = tree_structure(carry_avals + xs_avals) - new_jaxpr, new_consts, _ = lcf._initial_style_jaxpr( - new_body, in_tree, tuple(carry_avals + x_avals)) - vals = tuple(it.chain(new_consts, carry_vals, xs_vals)) - out_vals = lax.scan_p.bind(*vals, reverse=reverse, length=length, - num_consts=len(new_consts), num_carry=num_carry, - jaxpr=new_jaxpr, linear=linear, unroll=unroll) - return safe_map(trace.pure, out_vals) - -custom_callback_rules[lax.scan_p] = _scan_callback_rule - - -def _while_callback_rule(trace, *tracers, cond_jaxpr, body_jaxpr, - cond_nconsts, body_nconsts): - cond_const_tracers, body_const_tracers, init_tracers = split_list( - tracers, [cond_nconsts, body_nconsts]) - init_avals = safe_map(lambda x: x.aval, init_tracers) - cond_const_vals, body_const_vals, init_vals = tree_map( - lambda x: x.val, (cond_const_tracers, body_const_tracers, init_tracers)) - - body_fun = jaxpr_as_fun(body_jaxpr) - cond_fun = jaxpr_as_fun(cond_jaxpr) - - def cond(*carry): - return cond_fun(*it.chain(cond_const_vals, carry)) - - def body(*carry): - return body_fun(*it.chain(body_const_vals, carry)) - - new_cond = callback_transform(cond, trace.callback, strip_calls=trace.strip_calls) # type: ignore - new_body = callback_transform(body, trace.callback, strip_calls=trace.strip_calls) # type: ignore - in_tree = tree_structure(init_avals) - - new_cond_jaxpr, new_cond_consts, _ = lcf._initial_style_jaxpr(new_cond, in_tree, tuple(init_avals)) - new_body_jaxpr, new_body_consts, _ = lcf._initial_style_jaxpr(new_body, in_tree, tuple(init_avals)) - out = lcf.while_p.bind( - *it.chain(new_cond_consts, new_body_consts, init_vals), - cond_nconsts=len(new_cond_consts), - body_nconsts=len(new_body_consts), - cond_jaxpr=new_cond_jaxpr, - body_jaxpr=new_body_jaxpr) - return safe_map(trace.pure, out) - -custom_callback_rules[lax.while_p] = _while_callback_rule - -def _custom_derivative_call_jaxpr_callback_rule(primitive, trace, *tracers, - fun_jaxpr, num_consts, **params): - main = trace.main - vals = [t.val for t in tracers] - - new_closed_jaxpr = callback_jaxpr(fun_jaxpr, trace.callback, strip_calls=trace.strip_calls) - if primitive == cd.custom_vjp_call_jaxpr_p: - thunk_name = 'fwd_jaxpr_thunk' - params['bwd'] = callback_subtrace(params['bwd'], main) - else: - raise NotImplementedError(primitive) - - thunk = params.pop(thunk_name) - @pe._memoize - def new_thunk(): - thunk_jaxpr = core.ClosedJaxpr(*thunk()) - closed_jaxpr = callback_jaxpr(thunk_jaxpr, trace.callback, trace.strip_calls) - return closed_jaxpr.jaxpr, closed_jaxpr.literals - - params[thunk_name] = new_thunk - new_fun_jaxpr, new_consts = new_closed_jaxpr.jaxpr, new_closed_jaxpr.literals - closed_fun_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(new_fun_jaxpr), ()) - new_num_consts = len(new_consts) + num_consts - out = primitive.bind(*it.chain(new_consts, vals), fun_jaxpr=closed_fun_jaxpr, - num_consts=new_num_consts, **params) - return safe_map(trace.pure, out) - -custom_callback_rules[cd.custom_vjp_call_jaxpr_p] = partial( - _custom_derivative_call_jaxpr_callback_rule, cd.custom_vjp_call_jaxpr_p) diff --git a/tests/BUILD b/tests/BUILD index c759bd553..14bce215f 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -66,12 +66,6 @@ jax_test( }, ) -jax_test( - name = "callback_test", - srcs = ["callback_test.py"], - deps = ["//jax:callback"], -) - jax_test( name = "core_test", srcs = ["core_test.py"], diff --git a/tests/callback_test.py b/tests/callback_test.py deleted file mode 100644 index ca4af5ce4..000000000 --- a/tests/callback_test.py +++ /dev/null @@ -1,287 +0,0 @@ -# Copyright 2020 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from absl.testing import absltest - -import jax -from jax._src import test_util as jtu -from jax.experimental.callback import find_by_value, rewrite, FoundValue -import jax.numpy as jnp -from jax import lax -from jax import jit -from jax import grad - -from jax.config import config -config.parse_flags_with_absl() - -class CallbackTest(jtu.JaxTestCase): - @jtu.sample_product(value=[jnp.inf, jnp.nan]) - def testFindByValueFound(self, value): - def f(x): - y = x ** 2 - z = 1 - y - r = 1 / z - return r * 0 - - with self.assertRaises(FoundValue): - find_by_value(f, value)(jnp.array([1.0, 2.0, 3.0])) - - @jtu.sample_product(value=[jnp.inf, jnp.nan]) - def testFindByValueFoundJIT(self, value): - def f(x): - @jit - def g(x): - y = x ** 2 - z = 1 - y - r = 1 / z - return r * 0 - return g(x) - with self.assertRaises(FoundValue): - find_by_value(f, value)(jnp.array([1.0, 2.0, 3.0])) - - @jtu.sample_product(value=[jnp.inf, jnp.nan]) - def testFindByValueNotFound(self, value): - def f(x): - y = x ** 2 - z = 1 - y - return z - - find_by_value(f, value)(jnp.array([1.0, 2.0, 3.0])) - - def testRewrite(self): - def f(x): - return x * 2 - - x = jnp.array([2.0, 4.0]) - self.assertAllClose(f(x), jnp.array([4.0, 8.0])) - - self.assertAllClose( - rewrite(f, {lax.mul_p: lambda x, y: x + y})(x), - jnp.array([4.0, 6.0])) - - def testRewriteJIT(self): - def f(x): - @jit - def g(x): - return x * 2 - return g(x) - - x = jnp.array([2.0, 4.0]) - self.assertAllClose(f(x), jnp.array([4.0, 8.0])) - - self.assertAllClose( - rewrite(f, {lax.mul_p: lambda x, y: x + y})(x), - jnp.array([4.0, 6.0])) - - def testRewriteWithCustomGradients(self): - def f(x): - return jax.nn.relu(x) - - x = jnp.array([2.0, 4.0]) - self.assertAllClose(f(x), jnp.array([2.0, 4.0])) - - self.assertAllClose( - rewrite(f, {})(x), - jnp.array([2.0, 4.0])) - - def testRewriteThroughScan(self): - def f(xs): - def body(carry, x): - carry = carry * 2. - return carry, x - 2. - return lax.scan(body, 1., xs) - - xs = jnp.arange(4.) - carry, ys = f(xs) - self.assertAllClose(carry, 16.) - self.assertAllClose(ys, jnp.arange(4.) - 2.) - - rewrites = { - lax.mul_p: lambda x, y: x + y, - lax.sub_p: lambda x, y: x / y - } - carry, ys = rewrite(f, rewrites)(xs) - self.assertAllClose(carry, 1. + 8.) - self.assertAllClose(ys, jnp.arange(4.) / 2.) - - - def testRewriteThroughWhile(self): - def f(x): - def cond(x): - return x < 5 - def body(x): - return x + 1 - return lax.while_loop(cond, body, x) - - x = 0 - self.assertAllClose(f(x), 5) - - rewrites = { - lax.add_p: lambda x, y: x + y + 100, - } - self.assertAllClose(rewrite(f, rewrites)(x), 101) - - rewrites = { - lax.lt_p: lambda x, y: x < y + 5 - } - self.assertAllClose(rewrite(f, rewrites)(x), 10) - - - def testRewriteThroughForLoop(self): - def f(x): - def body(i, x): - return x * i - return lax.fori_loop(1, 5, body, x) - - x = 1 - self.assertAllClose(f(x), 24) - - rewrites = { - lax.mul_p: lambda x, y: x + y - } - self.assertAllClose(rewrite(f, rewrites)(x), 11) - - def testRewriteThroughCustomVJP(self): - - @jax.custom_gradient - def f(x): - return x * 2, lambda g: g + x - - x = 2. - self.assertAllClose(f(x), 4.) - self.assertAllClose(grad(f)(x), 3.) - - rewrites = { - lax.mul_p: lambda x, y: x / y - } - g = rewrite(f, rewrites) - - self.assertAllClose(g(x), 1.) - self.assertAllClose(grad(g)(x), 3.) - - rewrites = { - lax.add_p: lambda x, y: x - y - } - g = rewrite(f, rewrites) - - self.assertAllClose(g(x), 4.) - self.assertAllClose(grad(g)(x), -1.) - - def testRewriteThroughCustomVJPInScan(self): - - @jax.custom_gradient - def foo(x): - return x * 2, lambda g: g + x - - def f(x): - out, _ = lax.scan(lambda c, _: (foo(c), None), x, None, length=1) - return out - - x = 2. - self.assertAllClose(f(x), 4.) - self.assertAllClose(grad(f)(x), 3.) - - rewrites = { - lax.mul_p: lambda x, y: x / y - } - g = rewrite(f, rewrites) - - self.assertAllClose(g(x), 1.) - self.assertAllClose(grad(g)(x), 3.) - - rewrites = { - lax.add_p: lambda x, y: x * y - } - g = rewrite(f, rewrites) - - self.assertAllClose(g(x), 4.) - self.assertAllClose(grad(g)(x), 2.) - - def testRewriteThroughCustomJVP(self): - - @jax.custom_jvp - def f(x): - return x + 2 - - @f.defjvp - def f_jvp(primals, tangents): - x, = primals - d, = tangents - return f(x), x * d - - x = 2. - self.assertAllClose(f(x), 4.) - f_primal, jvp = jax.jvp(f, (x,), (1.,)) - self.assertAllClose(f_primal, 4.) - self.assertAllClose(jvp, 2.) - self.assertAllClose(grad(f)(x), 2.) - - rewrites = { - lax.add_p: lambda x, y: x - y - } - g = rewrite(f, rewrites) - - self.assertAllClose(g(x), 0.) - g_primal, jvp = jax.jvp(g, (x,), (1.,)) - self.assertAllClose(g_primal, 0.) - self.assertAllClose(jvp, 2.) - self.assertAllClose(grad(g)(x), 2.) - - def testRewriteThroughCustomJVPInScan(self): - - @jax.custom_jvp - def foo(x): - return x + 2 - - @foo.defjvp - def foo_jvp(primals, tangents): - x, = primals - d, = tangents - return f(x), x * d - def f(x): - out, _ = lax.scan(lambda c, _: (foo(c), None), x, None, length=1) - return out - - x = 2. - self.assertAllClose(f(x), 4.) - f_primal, jvp = jax.jvp(f, (x,), (1.,)) - self.assertAllClose(f_primal, 4.) - self.assertAllClose(jvp, 2.) - self.assertAllClose(grad(f)(x), 2.) - - rewrites = { - lax.add_p: lambda x, y: x - y - } - g = rewrite(f, rewrites) - - self.assertAllClose(g(x), 0.) - g_primal, jvp = jax.jvp(g, (x,), (1.,)) - self.assertAllClose(g_primal, 0.) - self.assertAllClose(jvp, 2.) - self.assertAllClose(grad(g)(x), 2.) - - rewrites = { - lax.mul_p: lambda x, y: x + y - } - g = rewrite(f, rewrites) - - self.assertAllClose(g(x), 4.) - g_primal, jvp = jax.jvp(g, (x,), (1.,)) - self.assertAllClose(g_primal, 4.) - self.assertAllClose(jvp, 3.) - self.assertAllClose(grad(g)(x), 1.) - - -if __name__ == "__main__": - absltest.main(testLoader=jtu.JaxTestLoader())