mark lax higher-order functions for stack trace filtering

This commit is contained in:
Roy Frostig 2021-02-23 20:08:10 -08:00
parent 4c179d1b66
commit 1283a9654b
2 changed files with 172 additions and 1 deletions

View File

@ -44,6 +44,7 @@ from jax.interpreters import batching
from jax.interpreters import masking
from jax.lib import xla_bridge as xb
from jax.lib import xla_client
from jax._src.traceback_util import api_boundary
from jax._src.util import (partial, unzip2, unzip3, unzip4, safe_map, safe_zip,
split_list, cache, extend_name_stack)
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
@ -137,6 +138,7 @@ def _fori_scan_body_fun(body_fun):
return (lax.add(i, lax._const(i, 1)), upper, body_fun(i, x)), None
return scanned_fun
@api_boundary
def fori_loop(lower, upper, body_fun, init_val):
"""Loop from ``lower`` to ``upper`` by reduction to :func:`jax.lax.while_loop`.
@ -203,6 +205,7 @@ def fori_loop(lower, upper, body_fun, init_val):
return result
@api_boundary
def while_loop(cond_fun: Callable[[T], bool],
body_fun: Callable[[T], T],
init_val: T) -> T:
@ -549,6 +552,7 @@ batching.initial_style_batchers[while_p] = _while_loop_batching_rule
### cond and switch
@api_boundary
def switch(index, branches: Sequence[Callable], operand):
"""Apply exactly one of ``branches`` given by ``index``.
@ -696,6 +700,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, operand):
branches=(false_jaxpr, true_jaxpr), linear=linear)
return tree_unflatten(out_tree, out)
@api_boundary
@functools.wraps(_cond)
def cond(*args, **kwargs):
# detect an attempt to call the former, deprecated cond
@ -1132,6 +1137,7 @@ Carry = TypeVar('Carry')
X = TypeVar('X')
Y = TypeVar('Y')
@api_boundary
def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
init: Carry,
xs: X,
@ -1888,6 +1894,7 @@ masking.masking_rules[scan_p] = _scan_masking_rule
core.custom_typechecks[scan_p] = partial(_scan_typecheck, False)
@api_boundary
def map(f, xs):
"""Map a function over leading array axes.
@ -2005,6 +2012,7 @@ def _split_root_args(args, const_lengths):
return _RootTuple(*params_list[:-1]), params_list[-1]
@api_boundary
def custom_root(f, initial_guess, solve, tangent_solve):
"""Differentiably solve for a roots of a function.
@ -2138,6 +2146,7 @@ def _check_shapes(func_name, expected_name, actual, expected):
f"got {actual_shapes} and {expected_shapes}")
@api_boundary
def custom_linear_solve(
matvec, b, solve, transpose_solve=None, symmetric=False):
"""Perform a matrix-free linear solve with implicitly defined gradients.
@ -2377,6 +2386,7 @@ def _interleave(a, b, axis):
return lax.add(lax.pad(a, lax._const(a, 0), a_pad),
lax.pad(b, lax._const(b, 0), b_pad))
@api_boundary
def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
"""Performs a scan with an associative binary operation, in parallel.

View File

@ -18,7 +18,7 @@ import unittest
from absl.testing import absltest
from jax import grad, jit, vmap
from jax import grad, jit, vmap, lax
import jax.numpy as jnp
from jax import test_util as jtu
from jax._src import traceback_util
@ -121,6 +121,167 @@ class FilteredTracebackTest(jtu.JaxTestCase):
('outermost', 'return 2 + inbetween(x)'),
('inbetween', 'return 1 + grad(innermost)(x)')])
def test_lax_cond(self):
if not traceback_util.filtered_tracebacks_supported():
raise unittest.SkipTest('Filtered tracebacks not supported')
def err(_):
assert False
return ()
def f():
return lax.cond(True, err, lambda _: (), ())
check_filtered_stack_trace(self, AssertionError, f, [
('f', 'return lax.cond(True, err, lambda _: (), ())'),
('err', 'assert False')])
def test_lax_switch(self):
if not traceback_util.filtered_tracebacks_supported():
raise unittest.SkipTest('Filtered tracebacks not supported')
def err(_):
assert False
return ()
def f():
branches = [lambda _: (), err, lambda _: ()]
return lax.switch(1, branches, ())
check_filtered_stack_trace(self, AssertionError, f, [
('f', 'return lax.switch(1, branches, ())'),
('err', 'assert False')])
def test_lax_scan(self):
if not traceback_util.filtered_tracebacks_supported():
raise unittest.SkipTest('Filtered tracebacks not supported')
def err(*_):
assert False
return ()
def f():
return lax.scan(err, (), (), 3)
check_filtered_stack_trace(self, AssertionError, f, [
('f', 'return lax.scan(err, (), (), 3)'),
('err', 'assert False')])
def test_lax_fori_loop(self):
if not traceback_util.filtered_tracebacks_supported():
raise unittest.SkipTest('Filtered tracebacks not supported')
def err(*_):
assert False
return ()
def f():
return lax.fori_loop(0, 3, err, ())
check_filtered_stack_trace(self, AssertionError, f, [
('f', 'return lax.fori_loop(0, 3, err, ())'),
('err', 'assert False')])
def test_lax_while_loop(self):
if not traceback_util.filtered_tracebacks_supported():
raise unittest.SkipTest('Filtered tracebacks not supported')
def err(*_):
assert False
return ()
def f():
pred = lambda _: False
return lax.while_loop(pred, err, ())
check_filtered_stack_trace(self, AssertionError, f, [
('f', 'return lax.while_loop(pred, err, ())'),
('err', 'assert False')])
def test_lax_map(self):
if not traceback_util.filtered_tracebacks_supported():
raise unittest.SkipTest('Filtered tracebacks not supported')
def err(_):
assert False
return ()
def f():
xs = jnp.ones(3)
return lax.map(err, xs)
check_filtered_stack_trace(self, AssertionError, f, [
('f', 'return lax.map(err, xs)'),
('err', 'assert False')])
def test_lax_custom_root(self):
if not traceback_util.filtered_tracebacks_supported():
raise unittest.SkipTest('Filtered tracebacks not supported')
def err(*_):
assert False
return ()
def g(x): return (x - 1.) ** 2.
def solve(*_): return 1.
def f1():
return lax.custom_root(g, 0., err, solve)
def f2():
return lax.custom_root(g, 0., solve, err)
def f3():
return lax.custom_root(err, 0., solve, solve)
check_filtered_stack_trace(self, AssertionError, f1, [
('f1', 'return lax.custom_root(g, 0., err, solve)'),
('err', 'assert False')])
check_filtered_stack_trace(self, AssertionError, f2, [
('f2', 'return lax.custom_root(g, 0., solve, err)'),
('err', 'assert False')])
check_filtered_stack_trace(self, AssertionError, f3, [
('f3', 'return lax.custom_root(err, 0., solve, solve)'),
('err', 'assert False')])
def test_lax_custom_linear_solve(self):
if not traceback_util.filtered_tracebacks_supported():
raise unittest.SkipTest('Filtered tracebacks not supported')
def err(*_):
assert False
return ()
matvec = lambda v: v
solve = lambda mv, b: 1.
b = 1.
def f1():
return lax.custom_linear_solve(err, b, solve)
def f2():
return lax.custom_linear_solve(matvec, b, err)
check_filtered_stack_trace(self, AssertionError, f1, [
('f1', 'return lax.custom_linear_solve(err, b, solve)'),
('err', 'assert False')])
check_filtered_stack_trace(self, AssertionError, f2, [
('f2', 'return lax.custom_linear_solve(matvec, b, err)'),
('err', 'assert False')])
def test_lax_associative_scan(self):
if not traceback_util.filtered_tracebacks_supported():
raise unittest.SkipTest('Filtered tracebacks not supported')
def err(*_):
assert False
return ()
def f():
xs = jnp.arange(4.)
return lax.associative_scan(err, xs)
check_filtered_stack_trace(self, AssertionError, f, [
('f', 'return lax.associative_scan(err, xs)'),
('err', 'assert False')])
def test_cause_chain(self):
if not traceback_util.filtered_tracebacks_supported():
raise unittest.SkipTest('Filtered tracebacks not supported')