mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
mark lax higher-order functions for stack trace filtering
This commit is contained in:
parent
4c179d1b66
commit
1283a9654b
@ -44,6 +44,7 @@ from jax.interpreters import batching
|
||||
from jax.interpreters import masking
|
||||
from jax.lib import xla_bridge as xb
|
||||
from jax.lib import xla_client
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src.util import (partial, unzip2, unzip3, unzip4, safe_map, safe_zip,
|
||||
split_list, cache, extend_name_stack)
|
||||
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
|
||||
@ -137,6 +138,7 @@ def _fori_scan_body_fun(body_fun):
|
||||
return (lax.add(i, lax._const(i, 1)), upper, body_fun(i, x)), None
|
||||
return scanned_fun
|
||||
|
||||
@api_boundary
|
||||
def fori_loop(lower, upper, body_fun, init_val):
|
||||
"""Loop from ``lower`` to ``upper`` by reduction to :func:`jax.lax.while_loop`.
|
||||
|
||||
@ -203,6 +205,7 @@ def fori_loop(lower, upper, body_fun, init_val):
|
||||
return result
|
||||
|
||||
|
||||
@api_boundary
|
||||
def while_loop(cond_fun: Callable[[T], bool],
|
||||
body_fun: Callable[[T], T],
|
||||
init_val: T) -> T:
|
||||
@ -549,6 +552,7 @@ batching.initial_style_batchers[while_p] = _while_loop_batching_rule
|
||||
|
||||
### cond and switch
|
||||
|
||||
@api_boundary
|
||||
def switch(index, branches: Sequence[Callable], operand):
|
||||
"""Apply exactly one of ``branches`` given by ``index``.
|
||||
|
||||
@ -696,6 +700,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, operand):
|
||||
branches=(false_jaxpr, true_jaxpr), linear=linear)
|
||||
return tree_unflatten(out_tree, out)
|
||||
|
||||
@api_boundary
|
||||
@functools.wraps(_cond)
|
||||
def cond(*args, **kwargs):
|
||||
# detect an attempt to call the former, deprecated cond
|
||||
@ -1132,6 +1137,7 @@ Carry = TypeVar('Carry')
|
||||
X = TypeVar('X')
|
||||
Y = TypeVar('Y')
|
||||
|
||||
@api_boundary
|
||||
def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
|
||||
init: Carry,
|
||||
xs: X,
|
||||
@ -1888,6 +1894,7 @@ masking.masking_rules[scan_p] = _scan_masking_rule
|
||||
core.custom_typechecks[scan_p] = partial(_scan_typecheck, False)
|
||||
|
||||
|
||||
@api_boundary
|
||||
def map(f, xs):
|
||||
"""Map a function over leading array axes.
|
||||
|
||||
@ -2005,6 +2012,7 @@ def _split_root_args(args, const_lengths):
|
||||
return _RootTuple(*params_list[:-1]), params_list[-1]
|
||||
|
||||
|
||||
@api_boundary
|
||||
def custom_root(f, initial_guess, solve, tangent_solve):
|
||||
"""Differentiably solve for a roots of a function.
|
||||
|
||||
@ -2138,6 +2146,7 @@ def _check_shapes(func_name, expected_name, actual, expected):
|
||||
f"got {actual_shapes} and {expected_shapes}")
|
||||
|
||||
|
||||
@api_boundary
|
||||
def custom_linear_solve(
|
||||
matvec, b, solve, transpose_solve=None, symmetric=False):
|
||||
"""Perform a matrix-free linear solve with implicitly defined gradients.
|
||||
@ -2377,6 +2386,7 @@ def _interleave(a, b, axis):
|
||||
return lax.add(lax.pad(a, lax._const(a, 0), a_pad),
|
||||
lax.pad(b, lax._const(b, 0), b_pad))
|
||||
|
||||
@api_boundary
|
||||
def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0):
|
||||
"""Performs a scan with an associative binary operation, in parallel.
|
||||
|
||||
|
@ -18,7 +18,7 @@ import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
from jax import grad, jit, vmap
|
||||
from jax import grad, jit, vmap, lax
|
||||
import jax.numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax._src import traceback_util
|
||||
@ -121,6 +121,167 @@ class FilteredTracebackTest(jtu.JaxTestCase):
|
||||
('outermost', 'return 2 + inbetween(x)'),
|
||||
('inbetween', 'return 1 + grad(innermost)(x)')])
|
||||
|
||||
def test_lax_cond(self):
|
||||
if not traceback_util.filtered_tracebacks_supported():
|
||||
raise unittest.SkipTest('Filtered tracebacks not supported')
|
||||
|
||||
def err(_):
|
||||
assert False
|
||||
return ()
|
||||
|
||||
def f():
|
||||
return lax.cond(True, err, lambda _: (), ())
|
||||
|
||||
check_filtered_stack_trace(self, AssertionError, f, [
|
||||
('f', 'return lax.cond(True, err, lambda _: (), ())'),
|
||||
('err', 'assert False')])
|
||||
|
||||
def test_lax_switch(self):
|
||||
if not traceback_util.filtered_tracebacks_supported():
|
||||
raise unittest.SkipTest('Filtered tracebacks not supported')
|
||||
|
||||
def err(_):
|
||||
assert False
|
||||
return ()
|
||||
|
||||
def f():
|
||||
branches = [lambda _: (), err, lambda _: ()]
|
||||
return lax.switch(1, branches, ())
|
||||
|
||||
check_filtered_stack_trace(self, AssertionError, f, [
|
||||
('f', 'return lax.switch(1, branches, ())'),
|
||||
('err', 'assert False')])
|
||||
|
||||
def test_lax_scan(self):
|
||||
if not traceback_util.filtered_tracebacks_supported():
|
||||
raise unittest.SkipTest('Filtered tracebacks not supported')
|
||||
|
||||
def err(*_):
|
||||
assert False
|
||||
return ()
|
||||
|
||||
def f():
|
||||
return lax.scan(err, (), (), 3)
|
||||
|
||||
check_filtered_stack_trace(self, AssertionError, f, [
|
||||
('f', 'return lax.scan(err, (), (), 3)'),
|
||||
('err', 'assert False')])
|
||||
|
||||
def test_lax_fori_loop(self):
|
||||
if not traceback_util.filtered_tracebacks_supported():
|
||||
raise unittest.SkipTest('Filtered tracebacks not supported')
|
||||
|
||||
def err(*_):
|
||||
assert False
|
||||
return ()
|
||||
|
||||
def f():
|
||||
return lax.fori_loop(0, 3, err, ())
|
||||
|
||||
check_filtered_stack_trace(self, AssertionError, f, [
|
||||
('f', 'return lax.fori_loop(0, 3, err, ())'),
|
||||
('err', 'assert False')])
|
||||
|
||||
def test_lax_while_loop(self):
|
||||
if not traceback_util.filtered_tracebacks_supported():
|
||||
raise unittest.SkipTest('Filtered tracebacks not supported')
|
||||
|
||||
def err(*_):
|
||||
assert False
|
||||
return ()
|
||||
|
||||
def f():
|
||||
pred = lambda _: False
|
||||
return lax.while_loop(pred, err, ())
|
||||
|
||||
check_filtered_stack_trace(self, AssertionError, f, [
|
||||
('f', 'return lax.while_loop(pred, err, ())'),
|
||||
('err', 'assert False')])
|
||||
|
||||
def test_lax_map(self):
|
||||
if not traceback_util.filtered_tracebacks_supported():
|
||||
raise unittest.SkipTest('Filtered tracebacks not supported')
|
||||
|
||||
def err(_):
|
||||
assert False
|
||||
return ()
|
||||
|
||||
def f():
|
||||
xs = jnp.ones(3)
|
||||
return lax.map(err, xs)
|
||||
|
||||
check_filtered_stack_trace(self, AssertionError, f, [
|
||||
('f', 'return lax.map(err, xs)'),
|
||||
('err', 'assert False')])
|
||||
|
||||
def test_lax_custom_root(self):
|
||||
if not traceback_util.filtered_tracebacks_supported():
|
||||
raise unittest.SkipTest('Filtered tracebacks not supported')
|
||||
|
||||
def err(*_):
|
||||
assert False
|
||||
return ()
|
||||
|
||||
def g(x): return (x - 1.) ** 2.
|
||||
def solve(*_): return 1.
|
||||
|
||||
def f1():
|
||||
return lax.custom_root(g, 0., err, solve)
|
||||
def f2():
|
||||
return lax.custom_root(g, 0., solve, err)
|
||||
def f3():
|
||||
return lax.custom_root(err, 0., solve, solve)
|
||||
|
||||
check_filtered_stack_trace(self, AssertionError, f1, [
|
||||
('f1', 'return lax.custom_root(g, 0., err, solve)'),
|
||||
('err', 'assert False')])
|
||||
check_filtered_stack_trace(self, AssertionError, f2, [
|
||||
('f2', 'return lax.custom_root(g, 0., solve, err)'),
|
||||
('err', 'assert False')])
|
||||
check_filtered_stack_trace(self, AssertionError, f3, [
|
||||
('f3', 'return lax.custom_root(err, 0., solve, solve)'),
|
||||
('err', 'assert False')])
|
||||
|
||||
def test_lax_custom_linear_solve(self):
|
||||
if not traceback_util.filtered_tracebacks_supported():
|
||||
raise unittest.SkipTest('Filtered tracebacks not supported')
|
||||
|
||||
def err(*_):
|
||||
assert False
|
||||
return ()
|
||||
|
||||
matvec = lambda v: v
|
||||
solve = lambda mv, b: 1.
|
||||
b = 1.
|
||||
|
||||
def f1():
|
||||
return lax.custom_linear_solve(err, b, solve)
|
||||
def f2():
|
||||
return lax.custom_linear_solve(matvec, b, err)
|
||||
|
||||
check_filtered_stack_trace(self, AssertionError, f1, [
|
||||
('f1', 'return lax.custom_linear_solve(err, b, solve)'),
|
||||
('err', 'assert False')])
|
||||
check_filtered_stack_trace(self, AssertionError, f2, [
|
||||
('f2', 'return lax.custom_linear_solve(matvec, b, err)'),
|
||||
('err', 'assert False')])
|
||||
|
||||
def test_lax_associative_scan(self):
|
||||
if not traceback_util.filtered_tracebacks_supported():
|
||||
raise unittest.SkipTest('Filtered tracebacks not supported')
|
||||
|
||||
def err(*_):
|
||||
assert False
|
||||
return ()
|
||||
|
||||
def f():
|
||||
xs = jnp.arange(4.)
|
||||
return lax.associative_scan(err, xs)
|
||||
|
||||
check_filtered_stack_trace(self, AssertionError, f, [
|
||||
('f', 'return lax.associative_scan(err, xs)'),
|
||||
('err', 'assert False')])
|
||||
|
||||
def test_cause_chain(self):
|
||||
if not traceback_util.filtered_tracebacks_supported():
|
||||
raise unittest.SkipTest('Filtered tracebacks not supported')
|
||||
|
Loading…
x
Reference in New Issue
Block a user