[jax2tf] Implemented conversion of lu_p using _convert_jax_impl.

Benjamin Chetioui 2020-10-16 15:22:37 +02:00
parent 059cd71f33
commit 7d30187767
4 changed files with 66 additions and 1 deletion
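As an editor's illustration (not part of the commit), the sketch below shows what the change enables: a JAX computation that calls the LU decomposition can now be converted with jax2tf.convert instead of being rejected because lu_p had no TF lowering. It assumes TensorFlow is installed and uses the 2020-era jax.lax_linalg module that the diffs below also import; adjust imports for newer JAX versions.

import numpy as np
from jax import lax_linalg
from jax.experimental import jax2tf

def lu_fn(x):
  # lax_linalg.lu returns the packed LU factors, the pivots, and the
  # row permutation as three arrays.
  return lax_linalg.lu(x)

x = np.random.RandomState(0).rand(5, 5).astype(np.float32)
tf_lu_fn = jax2tf.convert(lu_fn)  # before this commit, lu_p was in tf_not_yet_impl
print(tf_lu_fn(x))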


@@ -618,7 +618,6 @@ tf_not_yet_impl = [
  lax.reduce_p, lax.rng_uniform_p,
  lax.linear_solve_p,
  lax_linalg.lu_p,
  lax.igamma_grad_a_p,
  lax.random_gamma_grad_p,
@@ -1789,6 +1788,12 @@ def _eigh(operand: TfVal, lower: bool):
tf_impl[lax_linalg.eigh_p] = _eigh

def _lu(operand: TfVal, _in_avals, _out_aval):
  return _convert_jax_impl(lax_linalg._lu_python)(operand, _in_avals=_in_avals,
                                                  _out_aval=_out_aval)

tf_impl_with_avals[lax_linalg.lu_p] = _lu

def _triangular_solve(a: TfVal, b: TfVal, *, left_side: bool, lower: bool,
                      transpose_a: bool, conjugate_a: bool,
                      unit_diagonal: bool):
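The registration above goes through tf_impl_with_avals because _convert_jax_impl appears to re-trace a pure-JAX implementation of the primitive (here the private lax_linalg._lu_python) and needs the input/output abstract values to do so; the resulting TF graph is then built from simpler primitives that jax2tf already converts, rather than from a single TF LU op. A small, hedged illustration of that idea outside jax2tf, assuming _lu_python can be traced directly on a concrete operand:

# Editor's sketch: inspect the jaxpr of the pure-JAX LU implementation to see
# that it is expressed in primitives jax2tf already handles (loops, slicing,
# arithmetic, ...). _lu_python is private JAX API referenced by the hunk above.
import numpy as np
import jax
from jax import lax_linalg

x = np.zeros((5, 5), dtype=np.float32)
print(jax.make_jaxpr(lax_linalg._lu_python)(x))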


@@ -141,6 +141,10 @@ def categorize(prim: core.Primitive, *args, **kwargs) \
      tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "
                                          "mode (experimental_compile=True))"))

  if prim is lax_linalg.lu_p:
    if np_dtype == np.complex64:
      tf_unimpl(np_dtype, devs=["TPU"])

  if prim is lax_linalg.triangular_solve_p:
    if np_dtype in [dtypes.bfloat16, np.float16]:
      tf_unimpl(np_dtype)


@@ -636,6 +636,20 @@ lax_linalg_eigh = tuple(
  if dtype != np.float16
)

lax_linalg_lu = tuple(
  Harness(f"_shape={jtu.format_shape_dtype_string(shape, dtype)}",
          lax_linalg.lu,
          [RandArg(shape, dtype)],
          shape=shape,
          dtype=dtype)
  for dtype in jtu.dtypes.all_inexact
  for shape in [
    (5, 5),     # square
    (3, 5, 5),  # batched
    (3, 5),     # non-square
  ]
)

def _make_triangular_solve_harness(name, *, left_side=True, lower=False,
                                   ab_shapes=((4, 4), (4, 1)), dtype=np.float32,
                                   transpose_a=False, conjugate_a=False,
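The new lax_linalg_lu harnesses above sweep every inexact dtype over a square, a batched, and a non-square shape. For orientation, here is a standalone, hedged equivalent of one such case (the batched (3, 5, 5) float32 shape), written without the Harness/RandArg machinery:

# Editor's sketch of one harness case: a random batched operand passed to
# lax_linalg.lu, roughly what RandArg + dyn_args_maker produce in the tests.
import numpy as np
from jax import lax_linalg

rng = np.random.RandomState(0)
operand = rng.rand(3, 5, 5).astype(np.float32)
lu, pivots, permutation = lax_linalg.lu(operand)
print(lu.shape, pivots.shape, permutation.shape)  # expected: (3, 5, 5) (3, 5) (3, 5)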


@@ -24,6 +24,7 @@ import itertools
import jax
from jax import dtypes
from jax import lax
from jax import lax_linalg
from jax import numpy as jnp
from jax import test_util as jtu
from jax.config import config
@@ -402,6 +403,47 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
                           custom_assert=custom_assert,
                           always_custom_assert=always_custom_assert)

  @primitive_harness.parameterized(primitive_harness.lax_linalg_lu)
  def test_lu(self, harness: primitive_harness.Harness):
    dtype = harness.params["dtype"]
    if dtype in [np.float16, dtypes.bfloat16]:
      raise unittest.SkipTest(
          f"LU is not implemented in JAX for dtype {dtype}.")
    tol = None
    if dtype in [np.float32, np.complex64]:
      if jtu.device_under_test() == "tpu":
        tol = 0.1
      else:
        tol = 1e-5
    if dtype in [np.float64, np.complex128]:
      tol = 1e-13

    operand, = harness.dyn_args_maker(self.rng())

    def custom_assert(result_jax, result_tf):
      lu, pivots, perm = tuple(map(lambda t: t.numpy(), result_tf))
      batch_dims = operand.shape[:-2]
      m, n = operand.shape[-2], operand.shape[-1]

      def _make_permutation_matrix(perm):
        result = []
        for idx in itertools.product(*map(range, operand.shape[:-1])):
          result += [0 if c != perm[idx] else 1 for c in range(m)]
        result = np.reshape(np.array(result, dtype=dtype), [*batch_dims, m, m])
        return result

      k = min(m, n)
      l = jnp.tril(lu, -1)[..., :, :k] + jnp.eye(m, k, dtype=dtype)
      u = jnp.triu(lu)[..., :k, :]
      p_mat = _make_permutation_matrix(perm)

      self.assertArraysEqual(lax_linalg.lu_pivots_to_permutation(pivots, m),
                             perm)
      self.assertAllClose(jnp.matmul(p_mat, operand), jnp.matmul(l, u),
                          atol=tol, rtol=tol)

    self.ConvertAndCompare(harness.dyn_fun, operand, atol=tol, rtol=tol,
                           custom_assert=custom_assert,
                           always_custom_assert=True)

  @primitive_harness.parameterized(
      primitive_harness.lax_linalg_triangular_solve)
  def test_triangular_solve(self, harness: primitive_harness.Harness):
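Because LU factorizations are not unique across backends, test_lu above cannot compare the TF and JAX outputs elementwise; custom_assert instead checks the pivots and the defining identity P @ A == L @ U, where L is unit lower-triangular and U is upper-triangular, both unpacked from the packed lu output. A standalone, hedged sketch of that identity check for a single unbatched 5x5 float32 matrix, mirroring the logic of custom_assert:

# Editor's sketch mirroring custom_assert for one unbatched square matrix.
import numpy as np
from jax import lax_linalg
from jax import numpy as jnp

a = np.random.RandomState(0).rand(5, 5).astype(np.float32)
lu, pivots, perm = lax_linalg.lu(a)

m = a.shape[-1]
l = jnp.tril(lu, -1) + jnp.eye(m, dtype=a.dtype)    # unit lower-triangular factor
u = jnp.triu(lu)                                    # upper-triangular factor
p_mat = np.eye(m, dtype=a.dtype)[np.asarray(perm)]  # row i has a 1 in column perm[i]

# P @ A must equal L @ U up to floating-point tolerance.
np.testing.assert_allclose(p_mat @ a, l @ u, atol=1e-5, rtol=1e-5)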