mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Allow integer inputs/outputs of linear_transpose
Refine check to only allow float/complex -> float/complex or int -> int functions (i.e. no mixing float/int inputs/outputs).
This commit is contained in:
parent
f65a327c76
commit
90f7e06bcc
@ -2005,14 +2005,18 @@ def linear_transpose(fun: Callable, *primals) -> Callable:
|
||||
flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
|
||||
in_avals = map(shaped_abstractify, primals_flat)
|
||||
in_dtypes = map(dtypes.dtype, in_avals)
|
||||
if any(not np.issubdtype(dtype, np.inexact) for dtype in in_dtypes):
|
||||
raise TypeError("linear_transpose only supports float and complex inputs, "
|
||||
f"but got {in_dtypes}")
|
||||
|
||||
in_pvals = map(pe.PartialVal.unknown, in_avals)
|
||||
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(flat_fun, in_pvals,
|
||||
instantiate=True)
|
||||
out_avals, _ = unzip2(out_pvals)
|
||||
out_dtypes = map(dtypes.dtype, out_avals)
|
||||
if not (all(dtypes.issubdtype(d, np.inexact) for d in in_dtypes + out_dtypes)
|
||||
or all(dtypes.issubdtype(d, np.integer)
|
||||
for d in in_dtypes + out_dtypes)):
|
||||
raise TypeError("linear_transpose only supports [float or complex] -> "
|
||||
"[float or complex], and integer -> integer functions, "
|
||||
f"but got {in_dtypes} -> {out_dtypes}.")
|
||||
|
||||
def transposed_fun(out_cotangent):
|
||||
out_cotangents, out_tree2 = tree_flatten(out_cotangent)
|
||||
|
@ -1126,11 +1126,17 @@ class APITest(jtu.JaxTestCase):
|
||||
z, = transpose_fun(y)
|
||||
self.assertArraysEqual(2 * y, z, check_dtypes=True)
|
||||
|
||||
def test_linear_transpose_integer(self):
|
||||
f = lambda x: 2 * x
|
||||
transpose = api.linear_transpose(f, 1)
|
||||
actual, = transpose(3)
|
||||
expected = 6
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
def test_linear_transpose_error(self):
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "linear_transpose only supports float and complex inputs"):
|
||||
api.linear_transpose(lambda x: x, 1)
|
||||
|
||||
TypeError, "linear_transpose only supports"):
|
||||
api.linear_transpose(lambda x: 2. * x, 1)
|
||||
transpose_fun = api.linear_transpose(lambda x: [x, x], 1.0)
|
||||
with self.assertRaisesRegex(TypeError, "cotangent tree does not match"):
|
||||
transpose_fun(1.0)
|
||||
|
Loading…
x
Reference in New Issue
Block a user