mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
introduce lax.switch
This commit is contained in:
parent
dc4c9f0450
commit
6015a2a689
@ -298,6 +298,7 @@ from .lax_control_flow import (
|
||||
scan,
|
||||
scan_bind,
|
||||
scan_p,
|
||||
switch,
|
||||
while_loop,
|
||||
while_p,
|
||||
associative_scan,
|
||||
|
@ -2841,6 +2841,7 @@ ad.defjvp(clamp_p,
|
||||
g, _zeros(operand)),
|
||||
lambda g, min, operand, max:
|
||||
select(lt(max, operand), _brcast(g, operand), _zeros(operand)))
|
||||
batching.defbroadcasting(clamp_p)
|
||||
|
||||
|
||||
def _concatenate_shape_rule(*operands, **kwargs):
|
||||
|
@ -543,7 +543,71 @@ ad.primitive_transposes[while_p] = _while_transpose_error
|
||||
batching.primitive_batchers[while_p] = _while_loop_batching_rule
|
||||
|
||||
|
||||
### cond
|
||||
### cond and switch
|
||||
|
||||
def switch(index, branches: Sequence[Callable], operand):
|
||||
"""Apply exactly one of ``branches`` given by ``index``.
|
||||
|
||||
If ``index`` is out of bounds, it is clamped to within bounds.
|
||||
|
||||
Has the semantics of the following Python::
|
||||
|
||||
def switch(index, branches, operand):
|
||||
index = clamp(0, index, len(branches) - 1)
|
||||
return branches[index](operand)
|
||||
|
||||
Arguments:
|
||||
index: Integer scalar type, indicating which branch function to apply.
|
||||
branches: Sequence of functions (A -> B) to be applied based on `index`.
|
||||
operand: Operand (A) input to whichever branch is applied.
|
||||
"""
|
||||
if len(onp.shape(index)) != 0:
|
||||
raise TypeError(
|
||||
f"Branch index must be scalar, "
|
||||
f"got {index} of shape {onp.shape(index)}.")
|
||||
|
||||
try:
|
||||
index_dtype = dtypes.result_type(index)
|
||||
except TypeError as err:
|
||||
msg = f"Index type must be an integer, got {index}."
|
||||
raise TypeError(msg) from err
|
||||
|
||||
if index_dtype.kind not in 'iu':
|
||||
raise TypeError(
|
||||
f"Index type must be an integer, got {index} as {index_dtype}")
|
||||
|
||||
branches = tuple(branches)
|
||||
|
||||
if len(branches) == 0:
|
||||
raise ValueError("Empty branch sequence")
|
||||
elif len(branches) == 1:
|
||||
return branches[0](operand)
|
||||
|
||||
index = lax.convert_element_type(index, onp.int32)
|
||||
lo = onp.array(0, onp.int32)
|
||||
hi = onp.array(len(branches) - 1, onp.int32)
|
||||
index = lax.clamp(lo, index, hi)
|
||||
|
||||
if (jax.api._jit_is_disabled() and
|
||||
isinstance(core.get_aval(index), ConcreteArray)):
|
||||
return branches[int(index)](operand)
|
||||
|
||||
ops, ops_tree = tree_flatten((operand,))
|
||||
ops_avals = tuple(_map(_abstractify, ops))
|
||||
|
||||
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
|
||||
branches, ops_tree, ops_avals)
|
||||
|
||||
for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
|
||||
_check_tree_and_avals(f"branch 0 and {i + 1} outputs",
|
||||
out_trees[0], jaxprs[0].out_avals,
|
||||
out_tree, jaxpr.out_avals)
|
||||
|
||||
linear = (False,) * (len(consts) + len(ops))
|
||||
out = cond_p.bind(
|
||||
index, *consts, *ops, branches=jaxprs, linear=linear)
|
||||
return tree_unflatten(out_trees[0], out)
|
||||
|
||||
|
||||
def cond(*args, **kwargs):
|
||||
"""Conditionally apply ``true_fun`` or ``false_fun``.
|
||||
@ -671,7 +735,7 @@ def _select_tree(indices, branch_vals):
|
||||
mid = onp.array(mid, dtypes.canonicalize_dtype(lax.dtype(indices)))
|
||||
return lax.select(lax.lt(indices, mid),
|
||||
_select_tree(indices, branch_vals[:mid]),
|
||||
_select_tree(indices, branch_vals[mid:]))
|
||||
_select_tree(indices - mid, branch_vals[mid:]))
|
||||
|
||||
def _cond_index_bcast_and_select_tree(indices, branch_vals):
|
||||
if all(core.get_aval(x) is core.abstract_unit for x in branch_vals):
|
||||
|
@ -39,6 +39,27 @@ from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
# Some tests are useful for testing both lax.cond and lax.switch. This function
|
||||
# provides a lax.cond-compatible interface to a two-branch lax.switch. Several
|
||||
# tests in this file are parameterized such that they either call into lax.cond
|
||||
# or into this function.
|
||||
def cond_via_switch(pred, true_fun, false_fun, op, *args):
|
||||
if len(args) > 0:
|
||||
assert len(args) == 1
|
||||
true_op, _true_fun, false_op, _false_fun = true_fun, false_fun, op, args[0]
|
||||
op = (false_op, true_op)
|
||||
false_fun = lambda op: _false_fun(op[0])
|
||||
true_fun = lambda op: _true_fun(op[1])
|
||||
index = lax.convert_element_type(pred, np.int32)
|
||||
return lax.switch(index, [false_fun, true_fun], op)
|
||||
|
||||
|
||||
COND_IMPLS = [
|
||||
(lax.cond, 'cond'),
|
||||
(cond_via_switch, 'switch'),
|
||||
]
|
||||
|
||||
|
||||
def while_loop_reference(cond, body, carry):
|
||||
while cond(carry):
|
||||
carry = body(carry)
|
||||
@ -508,6 +529,56 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
self.assertEqual(fun(4), cfun(4))
|
||||
self.assertEqual(fun(4), (8, 16))
|
||||
|
||||
def testSwitch(self):
|
||||
def branch(x):
|
||||
y = lax.mul(2, x)
|
||||
return y, lax.mul(2, y)
|
||||
|
||||
branches = [lambda x: (x, x),
|
||||
branch,
|
||||
lambda x: (x, -x)]
|
||||
|
||||
def fun(x):
|
||||
if x <= 0:
|
||||
return branches[0](x)
|
||||
elif x == 1:
|
||||
return branches[1](x)
|
||||
else:
|
||||
return branches[2](x)
|
||||
|
||||
def cfun(x):
|
||||
return lax.switch(x, branches, x)
|
||||
|
||||
self.assertEqual(fun(-1), cfun(-1))
|
||||
self.assertEqual(fun(0), cfun(0))
|
||||
self.assertEqual(fun(1), cfun(1))
|
||||
self.assertEqual(fun(2), cfun(2))
|
||||
self.assertEqual(fun(3), cfun(3))
|
||||
|
||||
cfun = api.jit(cfun)
|
||||
|
||||
self.assertEqual(fun(-1), cfun(-1))
|
||||
self.assertEqual(fun(0), cfun(0))
|
||||
self.assertEqual(fun(1), cfun(1))
|
||||
self.assertEqual(fun(2), cfun(2))
|
||||
self.assertEqual(fun(3), cfun(3))
|
||||
|
||||
def testOneBranchSwitch(self):
|
||||
branch = lambda x: -x
|
||||
f = lambda i, x: lax.switch(i, [branch], x)
|
||||
x = 7.
|
||||
self.assertEqual(f(-1, x), branch(x))
|
||||
self.assertEqual(f(0, x), branch(x))
|
||||
self.assertEqual(f(1, x), branch(x))
|
||||
cf = api.jit(f)
|
||||
self.assertEqual(cf(-1, x), branch(x))
|
||||
self.assertEqual(cf(0, x), branch(x))
|
||||
self.assertEqual(cf(1, x), branch(x))
|
||||
cf = api.jit(f, static_argnums=0)
|
||||
self.assertEqual(cf(-1, x), branch(x))
|
||||
self.assertEqual(cf(0, x), branch(x))
|
||||
self.assertEqual(cf(1, x), branch(x))
|
||||
|
||||
def testIssue1379(self):
|
||||
|
||||
def fun(pred):
|
||||
@ -527,7 +598,10 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
for f in [fun, cfun]:
|
||||
self.assertRaises(TypeError, f, pred)
|
||||
|
||||
def testNestedCond(self):
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{name}", "cond": cond}
|
||||
for cond, name in COND_IMPLS)
|
||||
def testNestedCond(self, cond):
|
||||
def fun(x):
|
||||
if x < 2:
|
||||
return lax.mul(2, x)
|
||||
@ -539,12 +613,12 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
@api.jit
|
||||
def cfun(x):
|
||||
return lax.cond(
|
||||
return cond(
|
||||
lax.lt(x, 2),
|
||||
lambda x: lax.mul(2, x),
|
||||
lambda x: lax.cond(lax.lt(x, 5),
|
||||
x, lambda x: lax.mul(3, x),
|
||||
4, lambda y: lax.mul(y, x)),
|
||||
lambda x: cond(lax.lt(x, 5),
|
||||
x, lambda x: lax.mul(3, x),
|
||||
4, lambda y: lax.mul(y, x)),
|
||||
x)
|
||||
|
||||
self.assertEqual(cfun(1), 2)
|
||||
@ -579,6 +653,33 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
lambda fop: jnp.float32(1.),
|
||||
1.)
|
||||
|
||||
def testSwitchErrors(self):
|
||||
"""Test typing error messages for switch."""
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
re.escape("Index type must be an integer, got <function")):
|
||||
lax.switch(lambda x: True, [lambda _: 2., lambda _: 3.], 1.)
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
re.escape("Index type must be an integer, got foo.")):
|
||||
lax.switch("foo", [lambda _: 2., lambda _: 3.], 1.)
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
re.escape("Branch index must be scalar, got (1.0, 1.0) of shape (2,).")):
|
||||
lax.switch((1., 1.), [lambda _: 2., lambda _: 3.], 1.)
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
re.escape("Empty branch sequence")):
|
||||
lax.switch(0, [], 1.)
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
re.escape("branch 0 and 1 outputs must have same type structure, got * and PyTreeDef(tuple, [*,*]).")):
|
||||
lax.switch(1, [lambda _: 2., lambda _: (3., 3.)], 1.)
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
TypeError,
|
||||
"branch 0 and 1 outputs must have identical types, got\n"
|
||||
"ShapedArray(float32[1])\n"
|
||||
"and\n"
|
||||
"ShapedArray(float32[])."):
|
||||
lax.switch(1, [lambda _: jnp.array([1.], jnp.float32),
|
||||
lambda _: jnp.float32(1.)],
|
||||
1.)
|
||||
|
||||
def testCondOneBranchConstant(self):
|
||||
def fun(x):
|
||||
if x < 3:
|
||||
@ -666,6 +767,60 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
assert "select" in str(jaxpr)
|
||||
|
||||
def testSwitchBatched(self):
|
||||
def fun(index, x, y, z):
|
||||
branches = [lambda xyz: xyz[0],
|
||||
lambda xyz: lax.neg(xyz[1]),
|
||||
lambda xyz: lax.sign(xyz[2])]
|
||||
return lax.switch(index, branches, (x, y, z))
|
||||
|
||||
# these cases stay as cond
|
||||
x = jnp.array(0)
|
||||
y = jnp.array([1, 2])
|
||||
z = jnp.array([3, 4])
|
||||
w = jnp.array(9)
|
||||
ans = api.vmap(fun, (None, 0, 0, None))(x, y, z, w)
|
||||
jaxpr = api.make_jaxpr(api.vmap(fun, (None, 0, 0, None)))(x, y, z, w)
|
||||
expected = np.array([1, 2])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
assert "select" not in str(jaxpr)
|
||||
|
||||
x = jnp.array(1)
|
||||
ans = api.vmap(fun, (None, 0, 0, None))(x, y, z, w)
|
||||
jaxpr = api.make_jaxpr(api.vmap(fun, (None, 0, 0, None)))(x, y, z, w)
|
||||
expected = np.array([-3, -4])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
assert "select" not in str(jaxpr)
|
||||
|
||||
fun = api.jit(fun)
|
||||
ans = api.vmap(fun, (None, 0, 0, None))(x, y, z, w)
|
||||
expected = np.array([-3, -4])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
z = jnp.array(5)
|
||||
ans = api.vmap(fun, (None, 0, None, None))(x, y, z, w)
|
||||
jaxpr = api.make_jaxpr(api.vmap(fun, (None, 0, None, None)))(x, y, z, w)
|
||||
expected = np.array([-5, -5])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
assert "select" not in str(jaxpr)
|
||||
|
||||
|
||||
# these cases become select
|
||||
x = jnp.array([0, 1])
|
||||
ans = api.vmap(fun, (0, 0, None, None))(x, y, z, w)
|
||||
jaxpr = api.make_jaxpr(api.vmap(fun, (0, 0, None, None)))(x, y, z, w)
|
||||
expected = np.array([1, -5])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
assert "select" in str(jaxpr)
|
||||
|
||||
z = jnp.array([3, 4])
|
||||
w = jnp.array([9, 9])
|
||||
ans = api.vmap(fun)(x, y, z, w)
|
||||
jaxpr = api.make_jaxpr(api.vmap(fun))(x, y, z, w)
|
||||
expected = np.array([1, -4])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
assert "select" in str(jaxpr)
|
||||
|
||||
def testCondJVP(self):
|
||||
def fun_ref(x):
|
||||
if x < 3:
|
||||
@ -692,7 +847,38 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
jtu.check_grads(fun, (x,), order=2, modes=["fwd"])
|
||||
|
||||
def testCondJVP2(self):
|
||||
def testSwitchJVP(self):
|
||||
def branch(x):
|
||||
y = 2 * x
|
||||
return y, 2 * y
|
||||
|
||||
branches = [lambda x: (x, x),
|
||||
branch,
|
||||
lambda x: (x, -x)]
|
||||
|
||||
def fun_ref(x):
|
||||
idx = x // 1
|
||||
if idx <= 0:
|
||||
return branches[0](x)
|
||||
elif idx == 1:
|
||||
return branches[1](x)
|
||||
else:
|
||||
return branches[2](x)
|
||||
|
||||
def fun(x):
|
||||
idx = lax.convert_element_type(x // 1, np.int32)
|
||||
return lax.switch(idx, branches, x)
|
||||
|
||||
for x in [-0.7, 0.7, 1.7, 2.7, 3.7]:
|
||||
ans = api.jvp(fun, (x,), (x,))
|
||||
expected = api.jvp(fun_ref, (x,), (x,))
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
jtu.check_grads(fun, (x,), order=2, modes=["fwd"])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{name}", "cond": cond}
|
||||
for cond, name in COND_IMPLS)
|
||||
def testCondJVP2(self, cond):
|
||||
def fun_ref(x):
|
||||
if x < 3:
|
||||
return 2.
|
||||
@ -700,7 +886,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
return 2. * x
|
||||
|
||||
def fun(x):
|
||||
return lax.cond(x < 3, (), lambda _: 2., x, lambda x: 2. * x)
|
||||
return cond(x < 3, (), lambda _: 2., x, lambda x: 2. * x)
|
||||
|
||||
x = 3.14
|
||||
ans = api.jvp(fun, (x,), (x,))
|
||||
@ -733,13 +919,40 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
jtu.check_grads(f, (x,), order=2, modes=["fwd", "rev"])
|
||||
|
||||
def testCondGrad2(self):
|
||||
def testSwitchGrad(self):
|
||||
branches = [lambda x: 3. * x,
|
||||
lambda x: jnp.sin(x),
|
||||
lambda x: -x]
|
||||
|
||||
def f_ref(x):
|
||||
idx = x // 1
|
||||
if idx <= 0:
|
||||
return branches[0](x)
|
||||
elif idx == 1:
|
||||
return branches[1](x)
|
||||
else:
|
||||
return branches[2](x)
|
||||
|
||||
def f(x):
|
||||
idx = lax.convert_element_type(x // 1, np.int32)
|
||||
return lax.switch(idx, branches, x)
|
||||
|
||||
for x in [-0.7, 0.7, 1.7, 2.7, 3.7]:
|
||||
ans = api.grad(f)(x)
|
||||
expected = api.grad(f_ref)(x)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
jtu.check_grads(f, (x,), order=2, modes=["fwd", "rev"])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{name}", "cond": cond}
|
||||
for cond, name in COND_IMPLS)
|
||||
def testCondGrad2(self, cond):
|
||||
def f_ref(x):
|
||||
z = jnp.array([1., 2.]) * x if x[0] < 2 else jnp.sin(x)
|
||||
return z.sum()
|
||||
|
||||
def _f(x):
|
||||
return lax.cond(
|
||||
return cond(
|
||||
x[0] < 2,
|
||||
lambda x: jnp.array([1., 2.]) * x,
|
||||
lambda x: jnp.sin(x),
|
||||
@ -760,7 +973,10 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
jtu.check_grads(f, (x,), order=2, modes=["fwd", "rev"],
|
||||
rtol={jnp.float32: 1e-2, jnp.float64: 2e-3})
|
||||
|
||||
def testCondGrad3(self):
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{name}", "cond": cond}
|
||||
for cond, name in COND_IMPLS)
|
||||
def testCondGrad3(self, cond):
|
||||
def fun_ref(x):
|
||||
if x < 3:
|
||||
return 2.
|
||||
@ -768,7 +984,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
return 2. * x
|
||||
|
||||
def fun(x):
|
||||
return lax.cond(x < 3, (), lambda _: 2., x, lambda x: 2. * x)
|
||||
return cond(x < 3, (), lambda _: 2., x, lambda x: 2. * x)
|
||||
|
||||
x = 3.14
|
||||
ans = api.grad(fun)(x)
|
||||
@ -782,7 +998,10 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
jtu.check_grads(fun, (x,), order=2, modes=["fwd", "rev"])
|
||||
|
||||
def testCondGrad4(self):
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{name}", "cond": cond}
|
||||
for cond, name in COND_IMPLS)
|
||||
def testCondGrad4(self, cond):
|
||||
def fun_ref(x, y):
|
||||
if x < 3:
|
||||
return 2. * jnp.sin(y)
|
||||
@ -790,7 +1009,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
return 2. * jnp.cos(x)
|
||||
|
||||
def fun(x, y):
|
||||
return lax.cond(
|
||||
return cond(
|
||||
x < 3,
|
||||
(), lambda _: 2. * jnp.sin(y),
|
||||
x, lambda x: 2. * x)
|
||||
@ -818,13 +1037,45 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(y, jnp.sin(4.), check_dtypes=False)
|
||||
self.assertAllClose(f_lin(2.), jnp.cos(4.) * 2., check_dtypes=False)
|
||||
|
||||
def testCondLinearize2(self):
|
||||
def testSwitchLinearize(self):
|
||||
branches = [lambda x: 3. * x,
|
||||
lambda x: jnp.sin(x),
|
||||
lambda x: -x]
|
||||
def f(x):
|
||||
idx = lax.convert_element_type(x // 1, np.int32)
|
||||
return lax.switch(idx, branches, x)
|
||||
|
||||
# branch 0
|
||||
y, f_lin = api.linearize(f, -1.)
|
||||
self.assertAllClose(y, -3., check_dtypes=False)
|
||||
self.assertAllClose(f_lin(2.), 6., check_dtypes=False)
|
||||
y, f_lin = api.linearize(f, 0.)
|
||||
self.assertAllClose(y, 0., check_dtypes=False)
|
||||
self.assertAllClose(f_lin(2.), 6., check_dtypes=False)
|
||||
|
||||
# branch 1
|
||||
y, f_lin = api.linearize(f, 1.)
|
||||
self.assertAllClose(y, jnp.sin(1.), check_dtypes=False)
|
||||
self.assertAllClose(f_lin(2.), jnp.cos(1.) * 2., check_dtypes=False)
|
||||
|
||||
# branch 2
|
||||
y, f_lin = api.linearize(f, 2.)
|
||||
self.assertAllClose(y, -2., check_dtypes=False)
|
||||
self.assertAllClose(f_lin(2.), -2., check_dtypes=False)
|
||||
y, f_lin = api.linearize(f, 3.)
|
||||
self.assertAllClose(y, -3., check_dtypes=False)
|
||||
self.assertAllClose(f_lin(2.), -2., check_dtypes=False)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{name}", "cond": cond}
|
||||
for cond, name in COND_IMPLS)
|
||||
def testCondLinearize2(self, cond):
|
||||
def f_ref(x):
|
||||
z = jnp.array([1., 2.]) * x if x[0] < 2 else jnp.cos(jnp.sin(x))
|
||||
return z.sum()
|
||||
|
||||
def f(x):
|
||||
return lax.cond(
|
||||
return cond(
|
||||
x[0] < 2,
|
||||
lambda x: jnp.array([1., 2.]) * x,
|
||||
lambda x: jnp.cos(jnp.sin(x)),
|
||||
@ -859,11 +1110,26 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
expected = f(4.)
|
||||
self.assertAllClose(y, expected, check_dtypes=False)
|
||||
|
||||
def testCondJitDisabled(self):
|
||||
def testSwitchJit(self):
|
||||
branches = [lambda x: 3. * x,
|
||||
lambda x: jnp.sin(x),
|
||||
lambda x: -x]
|
||||
def f(x):
|
||||
idx = lax.convert_element_type(x // 1, np.int32)
|
||||
return lax.switch(idx, branches, x)
|
||||
for x in [-1., 0., 1., 2., 3.]:
|
||||
y = api.jit(f)(x)
|
||||
expected = f(x)
|
||||
self.assertAllClose(y, expected, check_dtypes=False)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{name}", "cond": cond}
|
||||
for cond, name in COND_IMPLS)
|
||||
def testCondJitDisabled(self, cond):
|
||||
def f_ref(x):
|
||||
return 3. * x if x < 2 else jnp.sin(x)
|
||||
def f(x):
|
||||
return lax.cond(x < 2, lambda x: 3. * x, lambda x: jnp.sin(x), x)
|
||||
return cond(x < 2, lambda x: 3. * x, lambda x: jnp.sin(x), x)
|
||||
|
||||
with api.disable_jit():
|
||||
y = f(1.)
|
||||
@ -875,12 +1141,15 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
expected = f(1.)
|
||||
self.assertAllClose(y, expected, check_dtypes=False)
|
||||
|
||||
def testCondWithConsts(self):
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{name}", "cond": cond}
|
||||
for cond, name in COND_IMPLS)
|
||||
def testCondWithConsts(self, cond):
|
||||
def f(x):
|
||||
return lax.cond(x < 2,
|
||||
lambda x: np.array([1., 2.]) * x,
|
||||
lambda x: np.array([3., 4.]) * jnp.sin(x),
|
||||
x)
|
||||
return cond(x < 2,
|
||||
lambda x: np.array([1., 2.]) * x,
|
||||
lambda x: np.array([3., 4.]) * jnp.sin(x),
|
||||
x)
|
||||
|
||||
def f_ref(x):
|
||||
if x < 2:
|
||||
@ -895,12 +1164,15 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
expected = f_ref(4.)
|
||||
self.assertAllClose(y, expected, check_dtypes=False)
|
||||
|
||||
def testCondJitWithConsts(self):
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{name}", "cond": cond}
|
||||
for cond, name in COND_IMPLS)
|
||||
def testCondJitWithConsts(self, cond):
|
||||
def f(x):
|
||||
return lax.cond(x < 2,
|
||||
lambda x: np.array([1., 2.]) * x,
|
||||
lambda x: np.array([3., 4.]) * jnp.sin(x),
|
||||
x)
|
||||
return cond(x < 2,
|
||||
lambda x: np.array([1., 2.]) * x,
|
||||
lambda x: np.array([3., 4.]) * jnp.sin(x),
|
||||
x)
|
||||
|
||||
y = api.jit(f)(1.)
|
||||
expected = f(1.)
|
||||
@ -909,12 +1181,15 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
expected = f(4.)
|
||||
self.assertAllClose(y, expected, check_dtypes=False)
|
||||
|
||||
def testCondVmapGrad(self):
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_{name}", "cond": cond}
|
||||
for cond, name in COND_IMPLS)
|
||||
def testCondVmapGrad(self, cond):
|
||||
# https://github.com/google/jax/issues/2264
|
||||
def f_1(x): return x ** 2
|
||||
def f_2(x): return x ** 3
|
||||
|
||||
def f(x): return lax.cond(x > 0, f_1, f_2, x)
|
||||
def f(x): return cond(x > 0, f_1, f_2, x)
|
||||
def g(x): return jnp.where(x > 0, f_1(x), f_2(x))
|
||||
|
||||
x = jnp.linspace(-1, 1, 20)
|
||||
|
Loading…
x
Reference in New Issue
Block a user