mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[jax2tf] Implemented conversion of lu_p using _convert_jax_impl.
This commit is contained in:
parent
059cd71f33
commit
7d30187767
@ -618,7 +618,6 @@ tf_not_yet_impl = [
|
||||
lax.reduce_p, lax.rng_uniform_p,
|
||||
|
||||
lax.linear_solve_p,
|
||||
lax_linalg.lu_p,
|
||||
|
||||
lax.igamma_grad_a_p,
|
||||
lax.random_gamma_grad_p,
|
||||
@ -1789,6 +1788,12 @@ def _eigh(operand: TfVal, lower: bool):
|
||||
|
||||
tf_impl[lax_linalg.eigh_p] = _eigh
|
||||
|
||||
def _lu(operand: TfVal, _in_avals, _out_aval):
|
||||
return _convert_jax_impl(lax_linalg._lu_python)(operand, _in_avals=_in_avals,
|
||||
_out_aval=_out_aval)
|
||||
|
||||
tf_impl_with_avals[lax_linalg.lu_p] = _lu
|
||||
|
||||
def _triangular_solve(a: TfVal, b: TfVal, *, left_side: bool, lower: bool,
|
||||
transpose_a: bool, conjugate_a: bool,
|
||||
unit_diagonal: bool):
|
||||
|
@ -141,6 +141,10 @@ def categorize(prim: core.Primitive, *args, **kwargs) \
|
||||
tf_unimpl(np_dtype, additional_msg=("this is a problem only in compiled "
|
||||
"mode (experimental_compile=True))"))
|
||||
|
||||
if prim is lax_linalg.lu_p:
|
||||
if np_dtype == np.complex64:
|
||||
tf_unimpl(np_dtype, devs=["TPU"])
|
||||
|
||||
if prim is lax_linalg.triangular_solve_p:
|
||||
if np_dtype in [dtypes.bfloat16, np.float16]:
|
||||
tf_unimpl(np_dtype)
|
||||
|
@ -636,6 +636,20 @@ lax_linalg_eigh = tuple(
|
||||
if dtype != np.float16
|
||||
)
|
||||
|
||||
lax_linalg_lu = tuple(
|
||||
Harness(f"_shape={jtu.format_shape_dtype_string(shape, dtype)}",
|
||||
lax_linalg.lu,
|
||||
[RandArg(shape, dtype)],
|
||||
shape=shape,
|
||||
dtype=dtype)
|
||||
for dtype in jtu.dtypes.all_inexact
|
||||
for shape in [
|
||||
(5, 5), # square
|
||||
(3, 5, 5), # batched
|
||||
(3, 5), # non-square
|
||||
]
|
||||
)
|
||||
|
||||
def _make_triangular_solve_harness(name, *, left_side=True, lower=False,
|
||||
ab_shapes=((4, 4), (4, 1)), dtype=np.float32,
|
||||
transpose_a=False, conjugate_a=False,
|
||||
|
@ -24,6 +24,7 @@ import itertools
|
||||
import jax
|
||||
from jax import dtypes
|
||||
from jax import lax
|
||||
from jax import lax_linalg
|
||||
from jax import numpy as jnp
|
||||
from jax import test_util as jtu
|
||||
from jax.config import config
|
||||
@ -402,6 +403,47 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
custom_assert=custom_assert,
|
||||
always_custom_assert=always_custom_assert)
|
||||
|
||||
@primitive_harness.parameterized(primitive_harness.lax_linalg_lu)
|
||||
def test_lu(self, harness: primitive_harness.Harness):
|
||||
dtype = harness.params["dtype"]
|
||||
if dtype in [np.float16, dtypes.bfloat16]:
|
||||
raise unittest.SkipTest(
|
||||
f"LU is not implemented in JAX for dtype {dtype}.")
|
||||
tol = None
|
||||
if dtype in [np.float32, np.complex64]:
|
||||
if jtu.device_under_test() == "tpu":
|
||||
tol = 0.1
|
||||
else:
|
||||
tol = 1e-5
|
||||
if dtype in [np.float64, np.complex128]:
|
||||
tol = 1e-13
|
||||
operand, = harness.dyn_args_maker(self.rng())
|
||||
|
||||
def custom_assert(result_jax, result_tf):
|
||||
lu, pivots, perm = tuple(map(lambda t: t.numpy(), result_tf))
|
||||
batch_dims = operand.shape[:-2]
|
||||
m, n = operand.shape[-2], operand.shape[-1]
|
||||
def _make_permutation_matrix(perm):
|
||||
result = []
|
||||
for idx in itertools.product(*map(range, operand.shape[:-1])):
|
||||
result += [0 if c != perm[idx] else 1 for c in range(m)]
|
||||
result = np.reshape(np.array(result, dtype=dtype), [*batch_dims, m, m])
|
||||
return result
|
||||
|
||||
k = min(m, n)
|
||||
l = jnp.tril(lu, -1)[...,:, :k] + jnp.eye(m, k, dtype=dtype)
|
||||
u = jnp.triu(lu)[...,:k, :]
|
||||
p_mat = _make_permutation_matrix(perm)
|
||||
|
||||
self.assertArraysEqual(lax_linalg.lu_pivots_to_permutation(pivots, m),
|
||||
perm)
|
||||
self.assertAllClose(jnp.matmul(p_mat, operand), jnp.matmul(l, u),
|
||||
atol=tol, rtol=tol)
|
||||
|
||||
self.ConvertAndCompare(harness.dyn_fun, operand, atol=tol, rtol=tol,
|
||||
custom_assert=custom_assert,
|
||||
always_custom_assert=True)
|
||||
|
||||
@primitive_harness.parameterized(
|
||||
primitive_harness.lax_linalg_triangular_solve)
|
||||
def test_triangular_solve(self, harness: primitive_harness.Harness):
|
||||
|
Loading…
x
Reference in New Issue
Block a user