[jax2tf] Implemented a conversion for linear_solve_p

This commit is contained in:
Benjamin Chetioui 2020-10-16 12:51:47 +02:00
parent f785794aca
commit a8cfc74871
4 changed files with 59 additions and 6 deletions

View File

@ -618,8 +618,6 @@ for unexpected in [
tf_not_yet_impl = [
lax.reduce_p, lax.rng_uniform_p,
lax.linear_solve_p,
lax.igamma_grad_a_p,
lax.random_gamma_grad_p,
@ -1818,6 +1816,12 @@ def _triangular_solve(a: TfVal, b: TfVal, *, left_side: bool, lower: bool,
tf_impl[lax_linalg.triangular_solve_p] = _triangular_solve
def _linear_solve(*args: TfVal, const_lengths, jaxprs, _in_avals, _out_aval):
return _convert_jax_impl(lax_control_flow._custom_linear_solve_impl)(
*args, const_lengths=const_lengths, jaxprs=jaxprs, _in_avals=_in_avals, _out_aval=_out_aval)
tf_impl_with_avals[lax.linear_solve_p] = _linear_solve
def _custom_jvp_call_jaxpr(*args: TfVal,
fun_jaxpr: core.ClosedJaxpr,
jvp_jaxpr_thunk: Callable,

View File

@ -1,6 +1,6 @@
# Primitives with limited support
*Last generated on (YYYY-MM-DD): 2020-10-16*
*Last generated on (YYYY-MM-DD): 2020-10-19*
## Updating the documentation
@ -52,6 +52,7 @@ conversion to Tensorflow.
| erfc | Missing TF support | Primitive is unimplemented in TF | bfloat16 | CPU, GPU |
| fft | Missing TF support | Primitive is unimplemented in TF; this is a problem only in compiled mode (experimental_compile=True)) | complex128, float64 | CPU, GPU, TPU |
| lgamma | Missing TF support | Primitive is unimplemented in TF | bfloat16 | CPU, GPU |
| lu | Missing TF support | Primitive is unimplemented in TF | complex64 | TPU |
| max | Missing TF support | Primitive is unimplemented in TF | bool, complex128, complex64, int8, uint16, uint32, uint64 | CPU, GPU, TPU |
| min | Missing TF support | Primitive is unimplemented in TF | bool, complex128, complex64, int8, uint16, uint32, uint64 | CPU, GPU, TPU |
| mul | Missing TF support | Primitive is unimplemented in TF | uint32, uint64 | CPU, GPU, TPU |
@ -78,16 +79,17 @@ conversion to Tensorflow.
| sort | Missing TF support | Primitive is unimplemented in TF; stable sort not implemented for XlaSort | ALL | CPU, GPU, TPU |
| svd | Missing TF support | Primitive is unimplemented in TF; this works on JAX because JAX uses a custom implementation | complex128, complex64 | CPU, GPU |
| top_k | Missing TF support | Primitive is unimplemented in TF; this is a problem only in compiled mode (experimental_compile=True)) | float64, int64, uint64 | CPU, GPU, TPU |
| triangular_solve | Missing TF support | Primitive is unimplemented in TF | bfloat16, float16 | CPU, GPU, TPU |
## Not yet implemented primitive conversions
The conversion of the following JAX primitives is not yet implemented:
`after_all`, `all_to_all`, `axis_index`, `create_token`, `cummax`, `cummin`, `custom_linear_solve`, `igamma_grad_a`, `infeed`, `lu`, `outfeed`, `pmax`, `pmin`, `ppermute`, `psum`, `random_gamma_grad`, `reduce`, `rng_uniform`, `triangular_solve`, `xla_pmap`
`after_all`, `all_to_all`, `axis_index`, `create_token`, `cummax`, `cummin`, `igamma_grad_a`, `infeed`, `outfeed`, `pmax`, `pmin`, `ppermute`, `psum`, `random_gamma_grad`, `reduce`, `rng_uniform`, `xla_pmap`
## Primitive conversions with missing tests
The following JAX primitives have a defined conversion but are known to be
missing tests:
`argmax`, `argmin`, `broadcast`, `clamp`, `complex`, `conj`, `custom_lin`, `device_put`, `imag`, `integer_pow`, `real`, `rev`, `select_and_scatter`, `select_and_scatter_add`, `stop_gradient`, `tie_in`
`argmin`, `broadcast`, `clamp`, `complex`, `conj`, `custom_lin`, `device_put`, `integer_pow`, `rev`, `select_and_scatter`, `select_and_scatter_add`, `tie_in`

View File

@ -703,6 +703,41 @@ lax_linalg_triangular_solve = tuple( # Validate dtypes
# conjugate_a is irrelevant for real dtypes, and is thus omitted
)
def _make_linear_solve_harnesses():
def linear_solve(a, b, solve, transpose_solve=None, symmetric=False):
matvec = partial(lax.dot, a, precision=lax.Precision.HIGHEST)
return lax.custom_linear_solve(matvec, b, solve, transpose_solve, symmetric)
def explicit_jacobian_solve(matvec, b):
return lax.stop_gradient(jnp.linalg.solve(jax.api.jacobian(matvec)(b), b))
def _make_harness(name, *, shape=(4, 4), dtype=np.float32, symmetric=False,
solvers=(explicit_jacobian_solve, explicit_jacobian_solve)):
solve, transpose_solve = solvers
transpose_solve_name = transpose_solve.__name__ if transpose_solve else None
return Harness(f"_{name}_a={jtu.format_shape_dtype_string(shape, dtype)}_b={jtu.format_shape_dtype_string(shape[:-1], dtype)}_solve={solve.__name__}_transposesolve={transpose_solve_name}_symmetric={symmetric}",
linear_solve,
[RandArg(shape, dtype), RandArg(shape[:-1], dtype),
StaticArg(solve), StaticArg(transpose_solve),
StaticArg(symmetric)],
shape=shape,
dtype=dtype,
solve=solve,
transpose_solve=transpose_solve,
symmetric=symmetric)
return tuple( # Validate dtypes
_make_harness("dtypes", dtype=dtype)
for dtype in
jtu.dtypes.all_floating if not dtype in [np.float16, dtypes.bfloat16]
) + tuple( # Validate symmetricity
[_make_harness("symmetric", symmetric=True)]
) + tuple( # Validate removing transpose_solve
[_make_harness("transpose_solve", solvers=(explicit_jacobian_solve, None))]
)
lax_linear_solve = _make_linear_solve_harnesses()
lax_slice = tuple(
Harness(f"_shape={shape}_start_indices={start_indices}_limit_indices={limit_indices}_strides={strides}", # type: ignore
lax.slice,
@ -917,7 +952,7 @@ lax_reduce_window = tuple( # Validate dtypes across all execution paths
(lax.max, 1), # path through reduce_window
] + ([
(lax.add, 0), # path_through reduce_window_sum
(lax.mul, 1), # path through reduce_window_mul
(lax.mul, 1), # path through reduce_window
] if dtype != jnp.bool_ else [])
) + tuple( # Validate window_dimensions
_make_reduce_window_harness("window_dimensions",

View File

@ -457,6 +457,18 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
atol=atol, rtol=rtol)
@primitive_harness.parameterized(primitive_harness.lax_linear_solve)
def test_linear_solve(self, harness: primitive_harness.Harness):
a, b = harness.dyn_args_maker(self.rng())
if harness.params["symmetric"]:
a = a + a.T
tol = None
if (harness.params["dtype"] == np.float32 and
jtu.device_under_test() == "tpu"):
tol = 0.01
self.ConvertAndCompare(harness.dyn_fun, a, b, atol=tol, rtol=tol)
@primitive_harness.parameterized(primitive_harness.lax_unary_elementwise)
def test_unary_elementwise(self, harness: primitive_harness.Harness):
dtype = harness.params["dtype"]