introduce lax.switch

This commit is contained in:
Roy Frostig 2020-06-01 19:07:11 -07:00 committed by Roy Frostig
parent dc4c9f0450
commit 6015a2a689
4 changed files with 372 additions and 31 deletions

View File

@ -298,6 +298,7 @@ from .lax_control_flow import (
scan,
scan_bind,
scan_p,
switch,
while_loop,
while_p,
associative_scan,

View File

@ -2841,6 +2841,7 @@ ad.defjvp(clamp_p,
g, _zeros(operand)),
lambda g, min, operand, max:
select(lt(max, operand), _brcast(g, operand), _zeros(operand)))
batching.defbroadcasting(clamp_p)
def _concatenate_shape_rule(*operands, **kwargs):

View File

@ -543,7 +543,71 @@ ad.primitive_transposes[while_p] = _while_transpose_error
batching.primitive_batchers[while_p] = _while_loop_batching_rule
### cond
### cond and switch
def switch(index, branches: Sequence[Callable], operand):
"""Apply exactly one of ``branches`` given by ``index``.
If ``index`` is out of bounds, it is clamped to within bounds.
Has the semantics of the following Python::
def switch(index, branches, operand):
index = clamp(0, index, len(branches) - 1)
return branches[index](operand)
Arguments:
index: Integer scalar type, indicating which branch function to apply.
branches: Sequence of functions (A -> B) to be applied based on `index`.
operand: Operand (A) input to whichever branch is applied.
"""
if len(onp.shape(index)) != 0:
raise TypeError(
f"Branch index must be scalar, "
f"got {index} of shape {onp.shape(index)}.")
try:
index_dtype = dtypes.result_type(index)
except TypeError as err:
msg = f"Index type must be an integer, got {index}."
raise TypeError(msg) from err
if index_dtype.kind not in 'iu':
raise TypeError(
f"Index type must be an integer, got {index} as {index_dtype}")
branches = tuple(branches)
if len(branches) == 0:
raise ValueError("Empty branch sequence")
elif len(branches) == 1:
return branches[0](operand)
index = lax.convert_element_type(index, onp.int32)
lo = onp.array(0, onp.int32)
hi = onp.array(len(branches) - 1, onp.int32)
index = lax.clamp(lo, index, hi)
if (jax.api._jit_is_disabled() and
isinstance(core.get_aval(index), ConcreteArray)):
return branches[int(index)](operand)
ops, ops_tree = tree_flatten((operand,))
ops_avals = tuple(_map(_abstractify, ops))
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
branches, ops_tree, ops_avals)
for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
_check_tree_and_avals(f"branch 0 and {i + 1} outputs",
out_trees[0], jaxprs[0].out_avals,
out_tree, jaxpr.out_avals)
linear = (False,) * (len(consts) + len(ops))
out = cond_p.bind(
index, *consts, *ops, branches=jaxprs, linear=linear)
return tree_unflatten(out_trees[0], out)
def cond(*args, **kwargs):
"""Conditionally apply ``true_fun`` or ``false_fun``.
@ -671,7 +735,7 @@ def _select_tree(indices, branch_vals):
mid = onp.array(mid, dtypes.canonicalize_dtype(lax.dtype(indices)))
return lax.select(lax.lt(indices, mid),
_select_tree(indices, branch_vals[:mid]),
_select_tree(indices, branch_vals[mid:]))
_select_tree(indices - mid, branch_vals[mid:]))
def _cond_index_bcast_and_select_tree(indices, branch_vals):
if all(core.get_aval(x) is core.abstract_unit for x in branch_vals):

View File

@ -39,6 +39,27 @@ from jax.config import config
config.parse_flags_with_absl()
# Some tests are useful for testing both lax.cond and lax.switch. This function
# provides a lax.cond-compatible interface to a two-branch lax.switch. Several
# tests in this file are parameterized such that they either call into lax.cond
# or into this function.
def cond_via_switch(pred, true_fun, false_fun, op, *args):
if len(args) > 0:
assert len(args) == 1
true_op, _true_fun, false_op, _false_fun = true_fun, false_fun, op, args[0]
op = (false_op, true_op)
false_fun = lambda op: _false_fun(op[0])
true_fun = lambda op: _true_fun(op[1])
index = lax.convert_element_type(pred, np.int32)
return lax.switch(index, [false_fun, true_fun], op)
COND_IMPLS = [
(lax.cond, 'cond'),
(cond_via_switch, 'switch'),
]
def while_loop_reference(cond, body, carry):
while cond(carry):
carry = body(carry)
@ -508,6 +529,56 @@ class LaxControlFlowTest(jtu.JaxTestCase):
self.assertEqual(fun(4), cfun(4))
self.assertEqual(fun(4), (8, 16))
def testSwitch(self):
def branch(x):
y = lax.mul(2, x)
return y, lax.mul(2, y)
branches = [lambda x: (x, x),
branch,
lambda x: (x, -x)]
def fun(x):
if x <= 0:
return branches[0](x)
elif x == 1:
return branches[1](x)
else:
return branches[2](x)
def cfun(x):
return lax.switch(x, branches, x)
self.assertEqual(fun(-1), cfun(-1))
self.assertEqual(fun(0), cfun(0))
self.assertEqual(fun(1), cfun(1))
self.assertEqual(fun(2), cfun(2))
self.assertEqual(fun(3), cfun(3))
cfun = api.jit(cfun)
self.assertEqual(fun(-1), cfun(-1))
self.assertEqual(fun(0), cfun(0))
self.assertEqual(fun(1), cfun(1))
self.assertEqual(fun(2), cfun(2))
self.assertEqual(fun(3), cfun(3))
def testOneBranchSwitch(self):
branch = lambda x: -x
f = lambda i, x: lax.switch(i, [branch], x)
x = 7.
self.assertEqual(f(-1, x), branch(x))
self.assertEqual(f(0, x), branch(x))
self.assertEqual(f(1, x), branch(x))
cf = api.jit(f)
self.assertEqual(cf(-1, x), branch(x))
self.assertEqual(cf(0, x), branch(x))
self.assertEqual(cf(1, x), branch(x))
cf = api.jit(f, static_argnums=0)
self.assertEqual(cf(-1, x), branch(x))
self.assertEqual(cf(0, x), branch(x))
self.assertEqual(cf(1, x), branch(x))
def testIssue1379(self):
def fun(pred):
@ -527,7 +598,10 @@ class LaxControlFlowTest(jtu.JaxTestCase):
for f in [fun, cfun]:
self.assertRaises(TypeError, f, pred)
def testNestedCond(self):
@parameterized.named_parameters(
{"testcase_name": f"_{name}", "cond": cond}
for cond, name in COND_IMPLS)
def testNestedCond(self, cond):
def fun(x):
if x < 2:
return lax.mul(2, x)
@ -539,12 +613,12 @@ class LaxControlFlowTest(jtu.JaxTestCase):
@api.jit
def cfun(x):
return lax.cond(
return cond(
lax.lt(x, 2),
lambda x: lax.mul(2, x),
lambda x: lax.cond(lax.lt(x, 5),
x, lambda x: lax.mul(3, x),
4, lambda y: lax.mul(y, x)),
lambda x: cond(lax.lt(x, 5),
x, lambda x: lax.mul(3, x),
4, lambda y: lax.mul(y, x)),
x)
self.assertEqual(cfun(1), 2)
@ -579,6 +653,33 @@ class LaxControlFlowTest(jtu.JaxTestCase):
lambda fop: jnp.float32(1.),
1.)
def testSwitchErrors(self):
"""Test typing error messages for switch."""
with self.assertRaisesRegex(TypeError,
re.escape("Index type must be an integer, got <function")):
lax.switch(lambda x: True, [lambda _: 2., lambda _: 3.], 1.)
with self.assertRaisesRegex(TypeError,
re.escape("Index type must be an integer, got foo.")):
lax.switch("foo", [lambda _: 2., lambda _: 3.], 1.)
with self.assertRaisesRegex(TypeError,
re.escape("Branch index must be scalar, got (1.0, 1.0) of shape (2,).")):
lax.switch((1., 1.), [lambda _: 2., lambda _: 3.], 1.)
with self.assertRaisesRegex(ValueError,
re.escape("Empty branch sequence")):
lax.switch(0, [], 1.)
with self.assertRaisesRegex(TypeError,
re.escape("branch 0 and 1 outputs must have same type structure, got * and PyTreeDef(tuple, [*,*]).")):
lax.switch(1, [lambda _: 2., lambda _: (3., 3.)], 1.)
with self.assertRaisesWithLiteralMatch(
TypeError,
"branch 0 and 1 outputs must have identical types, got\n"
"ShapedArray(float32[1])\n"
"and\n"
"ShapedArray(float32[])."):
lax.switch(1, [lambda _: jnp.array([1.], jnp.float32),
lambda _: jnp.float32(1.)],
1.)
def testCondOneBranchConstant(self):
def fun(x):
if x < 3:
@ -666,6 +767,60 @@ class LaxControlFlowTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
assert "select" in str(jaxpr)
def testSwitchBatched(self):
def fun(index, x, y, z):
branches = [lambda xyz: xyz[0],
lambda xyz: lax.neg(xyz[1]),
lambda xyz: lax.sign(xyz[2])]
return lax.switch(index, branches, (x, y, z))
# these cases stay as cond
x = jnp.array(0)
y = jnp.array([1, 2])
z = jnp.array([3, 4])
w = jnp.array(9)
ans = api.vmap(fun, (None, 0, 0, None))(x, y, z, w)
jaxpr = api.make_jaxpr(api.vmap(fun, (None, 0, 0, None)))(x, y, z, w)
expected = np.array([1, 2])
self.assertAllClose(ans, expected, check_dtypes=False)
assert "select" not in str(jaxpr)
x = jnp.array(1)
ans = api.vmap(fun, (None, 0, 0, None))(x, y, z, w)
jaxpr = api.make_jaxpr(api.vmap(fun, (None, 0, 0, None)))(x, y, z, w)
expected = np.array([-3, -4])
self.assertAllClose(ans, expected, check_dtypes=False)
assert "select" not in str(jaxpr)
fun = api.jit(fun)
ans = api.vmap(fun, (None, 0, 0, None))(x, y, z, w)
expected = np.array([-3, -4])
self.assertAllClose(ans, expected, check_dtypes=False)
z = jnp.array(5)
ans = api.vmap(fun, (None, 0, None, None))(x, y, z, w)
jaxpr = api.make_jaxpr(api.vmap(fun, (None, 0, None, None)))(x, y, z, w)
expected = np.array([-5, -5])
self.assertAllClose(ans, expected, check_dtypes=False)
assert "select" not in str(jaxpr)
# these cases become select
x = jnp.array([0, 1])
ans = api.vmap(fun, (0, 0, None, None))(x, y, z, w)
jaxpr = api.make_jaxpr(api.vmap(fun, (0, 0, None, None)))(x, y, z, w)
expected = np.array([1, -5])
self.assertAllClose(ans, expected, check_dtypes=False)
assert "select" in str(jaxpr)
z = jnp.array([3, 4])
w = jnp.array([9, 9])
ans = api.vmap(fun)(x, y, z, w)
jaxpr = api.make_jaxpr(api.vmap(fun))(x, y, z, w)
expected = np.array([1, -4])
self.assertAllClose(ans, expected, check_dtypes=False)
assert "select" in str(jaxpr)
def testCondJVP(self):
def fun_ref(x):
if x < 3:
@ -692,7 +847,38 @@ class LaxControlFlowTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
jtu.check_grads(fun, (x,), order=2, modes=["fwd"])
def testCondJVP2(self):
def testSwitchJVP(self):
def branch(x):
y = 2 * x
return y, 2 * y
branches = [lambda x: (x, x),
branch,
lambda x: (x, -x)]
def fun_ref(x):
idx = x // 1
if idx <= 0:
return branches[0](x)
elif idx == 1:
return branches[1](x)
else:
return branches[2](x)
def fun(x):
idx = lax.convert_element_type(x // 1, np.int32)
return lax.switch(idx, branches, x)
for x in [-0.7, 0.7, 1.7, 2.7, 3.7]:
ans = api.jvp(fun, (x,), (x,))
expected = api.jvp(fun_ref, (x,), (x,))
self.assertAllClose(ans, expected, check_dtypes=False)
jtu.check_grads(fun, (x,), order=2, modes=["fwd"])
@parameterized.named_parameters(
{"testcase_name": f"_{name}", "cond": cond}
for cond, name in COND_IMPLS)
def testCondJVP2(self, cond):
def fun_ref(x):
if x < 3:
return 2.
@ -700,7 +886,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
return 2. * x
def fun(x):
return lax.cond(x < 3, (), lambda _: 2., x, lambda x: 2. * x)
return cond(x < 3, (), lambda _: 2., x, lambda x: 2. * x)
x = 3.14
ans = api.jvp(fun, (x,), (x,))
@ -733,13 +919,40 @@ class LaxControlFlowTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
jtu.check_grads(f, (x,), order=2, modes=["fwd", "rev"])
def testCondGrad2(self):
def testSwitchGrad(self):
branches = [lambda x: 3. * x,
lambda x: jnp.sin(x),
lambda x: -x]
def f_ref(x):
idx = x // 1
if idx <= 0:
return branches[0](x)
elif idx == 1:
return branches[1](x)
else:
return branches[2](x)
def f(x):
idx = lax.convert_element_type(x // 1, np.int32)
return lax.switch(idx, branches, x)
for x in [-0.7, 0.7, 1.7, 2.7, 3.7]:
ans = api.grad(f)(x)
expected = api.grad(f_ref)(x)
self.assertAllClose(ans, expected, check_dtypes=False)
jtu.check_grads(f, (x,), order=2, modes=["fwd", "rev"])
@parameterized.named_parameters(
{"testcase_name": f"_{name}", "cond": cond}
for cond, name in COND_IMPLS)
def testCondGrad2(self, cond):
def f_ref(x):
z = jnp.array([1., 2.]) * x if x[0] < 2 else jnp.sin(x)
return z.sum()
def _f(x):
return lax.cond(
return cond(
x[0] < 2,
lambda x: jnp.array([1., 2.]) * x,
lambda x: jnp.sin(x),
@ -760,7 +973,10 @@ class LaxControlFlowTest(jtu.JaxTestCase):
jtu.check_grads(f, (x,), order=2, modes=["fwd", "rev"],
rtol={jnp.float32: 1e-2, jnp.float64: 2e-3})
def testCondGrad3(self):
@parameterized.named_parameters(
{"testcase_name": f"_{name}", "cond": cond}
for cond, name in COND_IMPLS)
def testCondGrad3(self, cond):
def fun_ref(x):
if x < 3:
return 2.
@ -768,7 +984,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
return 2. * x
def fun(x):
return lax.cond(x < 3, (), lambda _: 2., x, lambda x: 2. * x)
return cond(x < 3, (), lambda _: 2., x, lambda x: 2. * x)
x = 3.14
ans = api.grad(fun)(x)
@ -782,7 +998,10 @@ class LaxControlFlowTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=False)
jtu.check_grads(fun, (x,), order=2, modes=["fwd", "rev"])
def testCondGrad4(self):
@parameterized.named_parameters(
{"testcase_name": f"_{name}", "cond": cond}
for cond, name in COND_IMPLS)
def testCondGrad4(self, cond):
def fun_ref(x, y):
if x < 3:
return 2. * jnp.sin(y)
@ -790,7 +1009,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
return 2. * jnp.cos(x)
def fun(x, y):
return lax.cond(
return cond(
x < 3,
(), lambda _: 2. * jnp.sin(y),
x, lambda x: 2. * x)
@ -818,13 +1037,45 @@ class LaxControlFlowTest(jtu.JaxTestCase):
self.assertAllClose(y, jnp.sin(4.), check_dtypes=False)
self.assertAllClose(f_lin(2.), jnp.cos(4.) * 2., check_dtypes=False)
def testCondLinearize2(self):
def testSwitchLinearize(self):
branches = [lambda x: 3. * x,
lambda x: jnp.sin(x),
lambda x: -x]
def f(x):
idx = lax.convert_element_type(x // 1, np.int32)
return lax.switch(idx, branches, x)
# branch 0
y, f_lin = api.linearize(f, -1.)
self.assertAllClose(y, -3., check_dtypes=False)
self.assertAllClose(f_lin(2.), 6., check_dtypes=False)
y, f_lin = api.linearize(f, 0.)
self.assertAllClose(y, 0., check_dtypes=False)
self.assertAllClose(f_lin(2.), 6., check_dtypes=False)
# branch 1
y, f_lin = api.linearize(f, 1.)
self.assertAllClose(y, jnp.sin(1.), check_dtypes=False)
self.assertAllClose(f_lin(2.), jnp.cos(1.) * 2., check_dtypes=False)
# branch 2
y, f_lin = api.linearize(f, 2.)
self.assertAllClose(y, -2., check_dtypes=False)
self.assertAllClose(f_lin(2.), -2., check_dtypes=False)
y, f_lin = api.linearize(f, 3.)
self.assertAllClose(y, -3., check_dtypes=False)
self.assertAllClose(f_lin(2.), -2., check_dtypes=False)
@parameterized.named_parameters(
{"testcase_name": f"_{name}", "cond": cond}
for cond, name in COND_IMPLS)
def testCondLinearize2(self, cond):
def f_ref(x):
z = jnp.array([1., 2.]) * x if x[0] < 2 else jnp.cos(jnp.sin(x))
return z.sum()
def f(x):
return lax.cond(
return cond(
x[0] < 2,
lambda x: jnp.array([1., 2.]) * x,
lambda x: jnp.cos(jnp.sin(x)),
@ -859,11 +1110,26 @@ class LaxControlFlowTest(jtu.JaxTestCase):
expected = f(4.)
self.assertAllClose(y, expected, check_dtypes=False)
def testCondJitDisabled(self):
def testSwitchJit(self):
branches = [lambda x: 3. * x,
lambda x: jnp.sin(x),
lambda x: -x]
def f(x):
idx = lax.convert_element_type(x // 1, np.int32)
return lax.switch(idx, branches, x)
for x in [-1., 0., 1., 2., 3.]:
y = api.jit(f)(x)
expected = f(x)
self.assertAllClose(y, expected, check_dtypes=False)
@parameterized.named_parameters(
{"testcase_name": f"_{name}", "cond": cond}
for cond, name in COND_IMPLS)
def testCondJitDisabled(self, cond):
def f_ref(x):
return 3. * x if x < 2 else jnp.sin(x)
def f(x):
return lax.cond(x < 2, lambda x: 3. * x, lambda x: jnp.sin(x), x)
return cond(x < 2, lambda x: 3. * x, lambda x: jnp.sin(x), x)
with api.disable_jit():
y = f(1.)
@ -875,12 +1141,15 @@ class LaxControlFlowTest(jtu.JaxTestCase):
expected = f(1.)
self.assertAllClose(y, expected, check_dtypes=False)
def testCondWithConsts(self):
@parameterized.named_parameters(
{"testcase_name": f"_{name}", "cond": cond}
for cond, name in COND_IMPLS)
def testCondWithConsts(self, cond):
def f(x):
return lax.cond(x < 2,
lambda x: np.array([1., 2.]) * x,
lambda x: np.array([3., 4.]) * jnp.sin(x),
x)
return cond(x < 2,
lambda x: np.array([1., 2.]) * x,
lambda x: np.array([3., 4.]) * jnp.sin(x),
x)
def f_ref(x):
if x < 2:
@ -895,12 +1164,15 @@ class LaxControlFlowTest(jtu.JaxTestCase):
expected = f_ref(4.)
self.assertAllClose(y, expected, check_dtypes=False)
def testCondJitWithConsts(self):
@parameterized.named_parameters(
{"testcase_name": f"_{name}", "cond": cond}
for cond, name in COND_IMPLS)
def testCondJitWithConsts(self, cond):
def f(x):
return lax.cond(x < 2,
lambda x: np.array([1., 2.]) * x,
lambda x: np.array([3., 4.]) * jnp.sin(x),
x)
return cond(x < 2,
lambda x: np.array([1., 2.]) * x,
lambda x: np.array([3., 4.]) * jnp.sin(x),
x)
y = api.jit(f)(1.)
expected = f(1.)
@ -909,12 +1181,15 @@ class LaxControlFlowTest(jtu.JaxTestCase):
expected = f(4.)
self.assertAllClose(y, expected, check_dtypes=False)
def testCondVmapGrad(self):
@parameterized.named_parameters(
{"testcase_name": f"_{name}", "cond": cond}
for cond, name in COND_IMPLS)
def testCondVmapGrad(self, cond):
# https://github.com/google/jax/issues/2264
def f_1(x): return x ** 2
def f_2(x): return x ** 3
def f(x): return lax.cond(x > 0, f_1, f_2, x)
def f(x): return cond(x > 0, f_1, f_2, x)
def g(x): return jnp.where(x > 0, f_1(x), f_2(x))
x = jnp.linspace(-1, 1, 20)