[jax2tf] Implemented a conversion for linear_solve_p

commit a8cfc74871 (parent f785794aca)
@@ -618,8 +618,6 @@ for unexpected in [

 tf_not_yet_impl = [
   lax.reduce_p, lax.rng_uniform_p,
-  lax.linear_solve_p,
   lax.igamma_grad_a_p,
   lax.random_gamma_grad_p,
@@ -1818,6 +1816,12 @@ def _triangular_solve(a: TfVal, b: TfVal, *, left_side: bool, lower: bool,

 tf_impl[lax_linalg.triangular_solve_p] = _triangular_solve


+def _linear_solve(*args: TfVal, const_lengths, jaxprs, _in_avals, _out_aval):
+  return _convert_jax_impl(lax_control_flow._custom_linear_solve_impl)(
+      *args, const_lengths=const_lengths, jaxprs=jaxprs, _in_avals=_in_avals, _out_aval=_out_aval)
+
+tf_impl_with_avals[lax.linear_solve_p] = _linear_solve
+
 def _custom_jvp_call_jaxpr(*args: TfVal,
                            fun_jaxpr: core.ClosedJaxpr,
                            jvp_jaxpr_thunk: Callable,
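With the handler registered above, a function built on lax.custom_linear_solve can be converted with jax2tf. A minimal sketch, assuming a hand-picked symmetric matrix and an explicit solver in the style of the test harness (the matrix, solver, and tolerance are assumptions, not from the commit):

# Sketch (assumed example): convert a custom_linear_solve call with jax2tf
# and check it against the pure-JAX result.
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import jax2tf

a = np.array([[2.0, 1.0], [1.0, 3.0]], dtype=np.float32)  # symmetric, well-conditioned

def matvec(x):
  return lax.dot(a, x, precision=lax.Precision.HIGHEST)  # linear in x

def solve(mv, y):
  # Explicit solver: materialize the matrix applied by `mv` via its Jacobian.
  return lax.stop_gradient(jnp.linalg.solve(jax.jacobian(mv)(y), y))

def f(b):
  return lax.custom_linear_solve(matvec, b, solve, symmetric=True)

b = np.array([1.0, 2.0], dtype=np.float32)
f_tf = jax2tf.convert(f)  # exercises the new lax.linear_solve_p handler
np.testing.assert_allclose(f(b), f_tf(b), rtol=1e-5)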
@@ -1,6 +1,6 @@

 # Primitives with limited support

-*Last generated on (YYYY-MM-DD): 2020-10-16*
+*Last generated on (YYYY-MM-DD): 2020-10-19*

 ## Updating the documentation
@@ -52,6 +52,7 @@ conversion to Tensorflow.

 | erfc | Missing TF support | Primitive is unimplemented in TF | bfloat16 | CPU, GPU |
 | fft | Missing TF support | Primitive is unimplemented in TF; this is a problem only in compiled mode (experimental_compile=True) | complex128, float64 | CPU, GPU, TPU |
 | lgamma | Missing TF support | Primitive is unimplemented in TF | bfloat16 | CPU, GPU |
+| lu | Missing TF support | Primitive is unimplemented in TF | complex64 | TPU |
 | max | Missing TF support | Primitive is unimplemented in TF | bool, complex128, complex64, int8, uint16, uint32, uint64 | CPU, GPU, TPU |
 | min | Missing TF support | Primitive is unimplemented in TF | bool, complex128, complex64, int8, uint16, uint32, uint64 | CPU, GPU, TPU |
 | mul | Missing TF support | Primitive is unimplemented in TF | uint32, uint64 | CPU, GPU, TPU |
@@ -78,16 +79,17 @@ conversion to Tensorflow.

 | sort | Missing TF support | Primitive is unimplemented in TF; stable sort not implemented for XlaSort | ALL | CPU, GPU, TPU |
 | svd | Missing TF support | Primitive is unimplemented in TF; this works on JAX because JAX uses a custom implementation | complex128, complex64 | CPU, GPU |
 | top_k | Missing TF support | Primitive is unimplemented in TF; this is a problem only in compiled mode (experimental_compile=True) | float64, int64, uint64 | CPU, GPU, TPU |
+| triangular_solve | Missing TF support | Primitive is unimplemented in TF | bfloat16, float16 | CPU, GPU, TPU |

 ## Not yet implemented primitive conversions

 The conversion of the following JAX primitives is not yet implemented:

-`after_all`, `all_to_all`, `axis_index`, `create_token`, `cummax`, `cummin`, `custom_linear_solve`, `igamma_grad_a`, `infeed`, `lu`, `outfeed`, `pmax`, `pmin`, `ppermute`, `psum`, `random_gamma_grad`, `reduce`, `rng_uniform`, `triangular_solve`, `xla_pmap`
+`after_all`, `all_to_all`, `axis_index`, `create_token`, `cummax`, `cummin`, `igamma_grad_a`, `infeed`, `outfeed`, `pmax`, `pmin`, `ppermute`, `psum`, `random_gamma_grad`, `reduce`, `rng_uniform`, `xla_pmap`

 ## Primitive conversions with missing tests

 The following JAX primitives have a defined conversion but are known to be
 missing tests:

-`argmax`, `argmin`, `broadcast`, `clamp`, `complex`, `conj`, `custom_lin`, `device_put`, `imag`, `integer_pow`, `real`, `rev`, `select_and_scatter`, `select_and_scatter_add`, `stop_gradient`, `tie_in`
+`argmin`, `broadcast`, `clamp`, `complex`, `conj`, `custom_lin`, `device_put`, `integer_pow`, `rev`, `select_and_scatter`, `select_and_scatter_add`, `tie_in`
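"Compiled mode" in the table refers to wrapping the converted function in tf.function with experimental_compile=True. A minimal sketch of that wiring (the choice of jnp.fft.fft is an assumption for illustration):

# Sketch: "compiled mode" as referenced in the table above. Whether a
# primitive fails here depends on TF/XLA support for the dtype, per the table.
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

f_tf = jax2tf.convert(jnp.fft.fft)
f_tf_xla = tf.function(f_tf, experimental_compile=True)  # XLA-compiled path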
@@ -703,6 +703,41 @@ lax_linalg_triangular_solve = tuple( # Validate dtypes

     # conjugate_a is irrelevant for real dtypes, and is thus omitted
 )

+def _make_linear_solve_harnesses():
+  def linear_solve(a, b, solve, transpose_solve=None, symmetric=False):
+    matvec = partial(lax.dot, a, precision=lax.Precision.HIGHEST)
+    return lax.custom_linear_solve(matvec, b, solve, transpose_solve, symmetric)
+
+  def explicit_jacobian_solve(matvec, b):
+    return lax.stop_gradient(jnp.linalg.solve(jax.api.jacobian(matvec)(b), b))
+
+  def _make_harness(name, *, shape=(4, 4), dtype=np.float32, symmetric=False,
+                    solvers=(explicit_jacobian_solve, explicit_jacobian_solve)):
+    solve, transpose_solve = solvers
+    transpose_solve_name = transpose_solve.__name__ if transpose_solve else None
+    return Harness(f"_{name}_a={jtu.format_shape_dtype_string(shape, dtype)}_b={jtu.format_shape_dtype_string(shape[:-1], dtype)}_solve={solve.__name__}_transposesolve={transpose_solve_name}_symmetric={symmetric}",
+                   linear_solve,
+                   [RandArg(shape, dtype), RandArg(shape[:-1], dtype),
+                    StaticArg(solve), StaticArg(transpose_solve),
+                    StaticArg(symmetric)],
+                   shape=shape,
+                   dtype=dtype,
+                   solve=solve,
+                   transpose_solve=transpose_solve,
+                   symmetric=symmetric)
+
+  return tuple( # Validate dtypes
+      _make_harness("dtypes", dtype=dtype)
+      for dtype in
+      jtu.dtypes.all_floating if not dtype in [np.float16, dtypes.bfloat16]
+  ) + tuple( # Validate symmetricity
+      [_make_harness("symmetric", symmetric=True)]
+  ) + tuple( # Validate removing transpose_solve
+      [_make_harness("transpose_solve", solvers=(explicit_jacobian_solve, None))]
+  )
+
+lax_linear_solve = _make_linear_solve_harnesses()
+

 lax_slice = tuple(
     Harness(f"_shape={shape}_start_indices={start_indices}_limit_indices={limit_indices}_strides={strides}", # type: ignore
             lax.slice,
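The explicit_jacobian_solve trick above works because matvec is linear: jax.api.jacobian(matvec)(b) materializes the matrix that matvec applies (the argument b only fixes the shape), so solving against it gives a dense reference solver. A standalone sketch, with assumed data and tolerances:

# Sketch (assumed example): the explicit-Jacobian reference solver outside
# the harness. jax.jacobian is the public spelling of jax.api.jacobian.
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax

rng = np.random.RandomState(0)
a = rng.randn(4, 4).astype(np.float32) + 4 * np.eye(4, dtype=np.float32)  # well-conditioned
b = rng.randn(4).astype(np.float32)

matvec = partial(lax.dot, a, precision=lax.Precision.HIGHEST)

def explicit_jacobian_solve(matvec, b):
  # jacobian(matvec)(b) == a, since matvec is x -> a @ x.
  return lax.stop_gradient(jnp.linalg.solve(jax.jacobian(matvec)(b), b))

x = lax.custom_linear_solve(matvec, b, explicit_jacobian_solve)
np.testing.assert_allclose(a @ x, b, rtol=1e-3, atol=1e-3)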
@@ -917,7 +952,7 @@ lax_reduce_window = tuple( # Validate dtypes across all execution paths

         (lax.max, 1), # path through reduce_window
     ] + ([
         (lax.add, 0), # path through reduce_window_sum
-        (lax.mul, 1), # path through reduce_window_mul
+        (lax.mul, 1), # path through reduce_window
     ] if dtype != jnp.bool_ else [])
 ) + tuple( # Validate window_dimensions
     _make_reduce_window_harness("window_dimensions",
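The (operator, init) pairs above couple each reduction with its identity element (0 for add, 1 for mul), which is what selects the different reduce_window lowering paths. A sketch of the primitive itself, with assumed window parameters:

# Sketch (assumed example): lax.reduce_window with identity init values.
import jax.numpy as jnp
from jax import lax

x = jnp.arange(1.0, 7.0).reshape(2, 3)
sums = lax.reduce_window(x, 0.0, lax.add, window_dimensions=(1, 2),
                         window_strides=(1, 1), padding="VALID")
prods = lax.reduce_window(x, 1.0, lax.mul, window_dimensions=(1, 2),
                          window_strides=(1, 1), padding="VALID")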
@@ -457,6 +457,18 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):

     self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
                            atol=atol, rtol=rtol)

+  @primitive_harness.parameterized(primitive_harness.lax_linear_solve)
+  def test_linear_solve(self, harness: primitive_harness.Harness):
+    a, b = harness.dyn_args_maker(self.rng())
+    if harness.params["symmetric"]:
+      a = a + a.T
+    tol = None
+    if (harness.params["dtype"] == np.float32 and
+        jtu.device_under_test() == "tpu"):
+      tol = 0.01
+
+    self.ConvertAndCompare(harness.dyn_fun, a, b, atol=tol, rtol=tol)
+
   @primitive_harness.parameterized(primitive_harness.lax_unary_elementwise)
   def test_unary_elementwise(self, harness: primitive_harness.Harness):
     dtype = harness.params["dtype"]
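ConvertAndCompare runs the harness function natively in JAX and through jax2tf.convert, then compares the numeric results under the given tolerances. A rough sketch of the idea (an assumption, not the helper's actual code):

# Sketch (assumption): the essence of the ConvertAndCompare check.
import numpy as np
from jax.experimental import jax2tf

def convert_and_compare(fun, *args, atol=None, rtol=None):
  expected = fun(*args)                # reference result from JAX
  actual = jax2tf.convert(fun)(*args)  # result from the converted TF function
  np.testing.assert_allclose(actual, expected,
                             atol=atol or 0.0, rtol=rtol or 1e-6)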