Allow integer inputs/outputs of linear_transpose

Refine check to only allow float/complex -> float/complex or int -> int
functions (i.e. no mixing float/int inputs/outputs).
This commit is contained in:
Jamie Townsend 2021-03-26 11:14:43 +00:00
parent f65a327c76
commit 90f7e06bcc
2 changed files with 16 additions and 6 deletions

View File

@ -2005,14 +2005,18 @@ def linear_transpose(fun: Callable, *primals) -> Callable:
flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
in_avals = map(shaped_abstractify, primals_flat)
in_dtypes = map(dtypes.dtype, in_avals)
if any(not np.issubdtype(dtype, np.inexact) for dtype in in_dtypes):
raise TypeError("linear_transpose only supports float and complex inputs, "
f"but got {in_dtypes}")
in_pvals = map(pe.PartialVal.unknown, in_avals)
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(flat_fun, in_pvals,
instantiate=True)
out_avals, _ = unzip2(out_pvals)
out_dtypes = map(dtypes.dtype, out_avals)
if not (all(dtypes.issubdtype(d, np.inexact) for d in in_dtypes + out_dtypes)
or all(dtypes.issubdtype(d, np.integer)
for d in in_dtypes + out_dtypes)):
raise TypeError("linear_transpose only supports [float or complex] -> "
"[float or complex], and integer -> integer functions, "
f"but got {in_dtypes} -> {out_dtypes}.")
def transposed_fun(out_cotangent):
out_cotangents, out_tree2 = tree_flatten(out_cotangent)

View File

@ -1126,11 +1126,17 @@ class APITest(jtu.JaxTestCase):
z, = transpose_fun(y)
self.assertArraysEqual(2 * y, z, check_dtypes=True)
def test_linear_transpose_integer(self):
f = lambda x: 2 * x
transpose = api.linear_transpose(f, 1)
actual, = transpose(3)
expected = 6
self.assertEqual(actual, expected)
def test_linear_transpose_error(self):
with self.assertRaisesRegex(
TypeError, "linear_transpose only supports float and complex inputs"):
api.linear_transpose(lambda x: x, 1)
TypeError, "linear_transpose only supports"):
api.linear_transpose(lambda x: 2. * x, 1)
transpose_fun = api.linear_transpose(lambda x: [x, x], 1.0)
with self.assertRaisesRegex(TypeError, "cotangent tree does not match"):
transpose_fun(1.0)