Activate Householder Product to XLA's FFI

PiperOrigin-RevId: 670196460
Paweł Paruzel 2024-09-02 06:18:13 -07:00 committed by jax authors
parent 8cb3596136
commit 414eb90f5b
4 changed files with 72 additions and 40 deletions


@@ -931,6 +931,7 @@ def _check_lowering(lowering) -> None:
 _CPU_FFI_KERNELS = [
     "lapack_spotrf_ffi", "lapack_dpotrf_ffi", "lapack_cpotrf_ffi", "lapack_zpotrf_ffi",
     "lapack_sgeqrf_ffi", "lapack_dgeqrf_ffi", "lapack_cgeqrf_ffi", "lapack_zgeqrf_ffi",
+    "lapack_sorgqr_ffi", "lapack_dorgqr_ffi", "lapack_cungqr_ffi", "lapack_zungqr_ffi",
     "lapack_ssyevd_ffi", "lapack_dsyevd_ffi", "lapack_cheevd_ffi", "lapack_zheevd_ffi",
     "lapack_sgeev_ffi", "lapack_dgeev_ffi", "lapack_cgeev_ffi", "lapack_zgeev_ffi",
     "lapack_sgesdd_ffi", "lapack_dgesdd_ffi", "lapack_cgesdd_ffi", "lapack_zgesdd_ffi",

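The added row registers the four Householder-product FFI kernels with the export allowlist, alongside the Cholesky, QR, eigendecomposition, and SVD targets that were already ported. As an illustrative aside (the helper below is hypothetical, not a JAX API), the naming convention these targets follow is:

import numpy as np

def householder_product_target(dtype) -> str:
  # Sketch of the target-naming convention only; not part of this change.
  # LAPACK routines carry a dtype prefix (s/d real, c/z complex); the real
  # routines are ?orgqr, the complex ones ?ungqr; FFI kernels add "_ffi".
  prefix = {np.float32: "s", np.float64: "d",
            np.complex64: "c", np.complex128: "z"}[np.dtype(dtype).type]
  base = "ungqr" if prefix in ("c", "z") else "orgqr"
  return f"lapack_{prefix}{base}_ffi"

assert householder_product_target(np.float64) == "lapack_dorgqr_ffi"
assert householder_product_target(np.complex64) == "lapack_cungqr_ffi"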

@@ -357,7 +357,7 @@ data_2024_08_22 = {}
 data_2024_08_22['c128'] = dict(
     testdata_version=1,
     platform='cpu',
-    custom_call_targets=['lapack_zgeqrf_ffi', 'lapack_zungqr'],
+    custom_call_targets=['lapack_zgeqrf_ffi', 'lapack_zungqr_ffi'],
     serialized_date=datetime.date(2024, 8, 22),
     inputs=(),
     expected_outputs=(
@@ -479,7 +479,7 @@ module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_p
 data_2024_08_22['c64'] = dict(
     testdata_version=1,
     platform='cpu',
-    custom_call_targets=['lapack_cgeqrf_ffi', 'lapack_cungqr'],
+    custom_call_targets=['lapack_cgeqrf_ffi', 'lapack_cungqr_ffi'],
     serialized_date=datetime.date(2024, 8, 22),
     inputs=(),
     expected_outputs=(
@@ -595,7 +595,7 @@ module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_p
 data_2024_08_22['f32'] = dict(
     testdata_version=1,
     platform='cpu',
-    custom_call_targets=['lapack_sgeqrf_ffi', 'lapack_sorgqr'],
+    custom_call_targets=['lapack_sgeqrf_ffi', 'lapack_sorgqr_ffi'],
     serialized_date=datetime.date(2024, 8, 22),
     inputs=(),
     expected_outputs=(
@@ -703,7 +703,7 @@ module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_p
 data_2024_08_22['f64'] = dict(
     testdata_version=1,
     platform='cpu',
-    custom_call_targets=['lapack_dgeqrf_ffi', 'lapack_dorgqr'],
+    custom_call_targets=['lapack_dgeqrf_ffi', 'lapack_dorgqr_ffi'],
     serialized_date=datetime.date(2024, 8, 22),
     inputs=(),
     expected_outputs=(

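The four test-data entries above pin the CPU QR lowering to the new *_ffi Householder-product targets for each dtype. A rough, illustrative way to observe the same target names outside the compatibility test data (exact names depend on the installed jaxlib and compatibility mode) is to inspect the lowered module text:

import jax
import jax.numpy as jnp

# Lower a small QR decomposition for CPU and search the emitted module text
# for the FFI Householder-product target (float32 -> "lapack_sorgqr_ffi").
lowered = jax.jit(jnp.linalg.qr).lower(jnp.ones((4, 3), dtype=jnp.float32))
hlo_text = lowered.as_text()
print("lapack_sorgqr_ffi" in hlo_text)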

@@ -1759,11 +1759,21 @@ def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus, *,
         f"on GPU is not implemented; b/261671778; {a_aval.shape}")
     a, info_orgqr = orgqr_impl(a_aval.dtype, a, taus)
   else:
+    # TODO(b/344892332): Remove the conditional after the compatibility period
+    ctx_args = (
+        (ctx,) if platform == "cpu" and jaxlib_version >= (0, 4, 32) else ()
+    )
     a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, a_aval.shape)
     tau_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, taus_aval.shape)
-    a, info_orgqr = orgqr_impl(a_aval.dtype, a, taus,
-                               a_shape_vals=a_shape_vals,
-                               tau_shape_vals=tau_shape_vals)
+    a, *maybe_info_orgqr = orgqr_impl(*ctx_args, a_aval.dtype, a, taus,
+                                      a_shape_vals=a_shape_vals,
+                                      tau_shape_vals=tau_shape_vals)
+    if not ctx.is_forward_compat():
+      # Skip the info parameter verification for the FFI kernel.
+      return [a]
+    # TODO(b/344892332): This parameter will no longer be needed after
+    # the forward compatibility period
+    info_orgqr = maybe_info_orgqr[0]
   zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32)))
   ok = mlir.compare_hlo(info_orgqr, zeros, "EQ", "SIGNED")
   select_a_aval = ShapedArray(batch_dims + [1, 1], np.dtype(np.bool_))

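The lowering rule above only forwards the lowering context to jaxlib when the installed jaxlib is new enough, and skips the info check entirely once the FFI kernel is in use, since FFI kernels report failures out of band. A self-contained sketch of that dispatch pattern (the names below are illustrative, not the JAX API):

def call_orgqr_like(impl, ctx, use_new_style, *args):
  # Newer implementations take the lowering context as their first argument.
  ctx_args = (ctx,) if use_new_style else ()
  out, *maybe_info = impl(*ctx_args, *args)
  if use_new_style:
    # No `info` output: the FFI kernel raises errors itself, so the caller
    # does not need to build a status check into the lowered module.
    return out, None
  # Legacy kernels return an `info` status that the caller must verify.
  return out, maybe_info[0]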

@@ -281,10 +281,11 @@ def geqrf_hlo(
 # # ?orgqr: product of elementary Householder reflectors:
-def orgqr_hlo(dtype, a: ir.Value, tau, *,
+def orgqr_hlo(ctx, dtype, a: ir.Value, tau, *,
               a_shape_vals: tuple[DimensionSize, ...],
               tau_shape_vals: tuple[DimensionSize, ...]):
   _lapack.initialize()
+  fn_base = "un" if dtype == np.complex64 or dtype == np.complex128 else "or"
+  fn_base = prepare_lapack_call(fn_base=fn_base + "gqr", dtype=dtype)
   a_type = ir.RankedTensorType(a.type)
   dims = a_type.shape
   dims_vals = a_shape_vals
@@ -294,55 +295,75 @@ def orgqr_hlo(dtype, a: ir.Value, tau, *,
   assert n != ir.ShapedType.get_dynamic_size()
   batch_dims_vals = dims_vals[:-2]
   num_bd = len(batch_dims_vals)
-  batch_size_val = hlo_s32(1)
-  for b_v in batch_dims_vals:
-    batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))
   k = tau_shape_vals[-1]
   assert type(k) is int
-  if dtype == np.float32:
-    fn = "lapack_sorgqr"
-    lwork = _lapack.lapack_sorgqr_workspace(m, n, k)
-  elif dtype == np.float64:
-    fn = "lapack_dorgqr"
-    lwork = _lapack.lapack_dorgqr_workspace(m, n, k)
-  elif dtype == np.complex64:
-    fn = "lapack_cungqr"
-    lwork = _lapack.lapack_cungqr_workspace(m, n, k)
-  elif dtype == np.complex128:
-    fn = "lapack_zungqr"
-    lwork = _lapack.lapack_zungqr_workspace(m, n, k)
-  else:
-    raise NotImplementedError(f"Unsupported dtype {dtype}")
-  scalar_layout = []
   layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
   i32_type = ir.IntegerType.get_signless(32)
+  if ctx.is_forward_compat():
+    fn = fn_base
+    batch_size_val = hlo_s32(1)
+    for b_v in batch_dims_vals:
+      batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))
+    if dtype == np.float32:
+      lwork = _lapack.lapack_sorgqr_workspace(m, n, k)
+    elif dtype == np.float64:
+      lwork = _lapack.lapack_dorgqr_workspace(m, n, k)
+    elif dtype == np.complex64:
+      lwork = _lapack.lapack_cungqr_workspace(m, n, k)
+    elif dtype == np.complex128:
+      lwork = _lapack.lapack_zungqr_workspace(m, n, k)
+    else:
+      raise NotImplementedError(f"Unsupported dtype {dtype}")
+    scalar_layout = []
+    shape_type_pairs: Sequence[ShapeTypePair] = [
+        (a_shape_vals, a_type.element_type),
+        (batch_dims_vals, i32_type),
+        ([lwork], a_type.element_type),
+    ]
+    result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs)
+    return custom_call(
+        fn,
+        result_types=result_types,
+        operands=[batch_size_val, hlo_s32(m), hlo_s32(n), hlo_s32(k),
+                  hlo_s32(lwork), a, tau],
+        operand_layouts=[scalar_layout] * 5 + [
+          layout,
+          tuple(range(num_bd, -1, -1)),
+        ],
+        result_layouts=[
+          layout,
+          tuple(range(num_bd - 1, -1, -1)),
+          [0],
+        ],
+        operand_output_aliases={5: 0},
+        result_shapes=result_shapes,
+    ).results[:2]
+  fn = fn_base + "_ffi"
   shape_type_pairs: Sequence[ShapeTypePair] = [
       (a_shape_vals, a_type.element_type),
-      (batch_dims_vals, i32_type),
-      ([lwork], a_type.element_type),
   ]
   result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs)
-  out = custom_call(
+  return custom_call(
       fn,
       result_types=result_types,
-      operands=[batch_size_val, hlo_s32(m), hlo_s32(n), hlo_s32(k),
-                hlo_s32(lwork), a, tau],
-      operand_layouts=[scalar_layout] * 5 + [
+      operands=[
+        a, tau
+      ],
+      operand_layouts=[
         layout,
         tuple(range(num_bd, -1, -1)),
       ],
       result_layouts=[
         layout,
-        tuple(range(num_bd - 1, -1, -1)),
-        [0],
       ],
-      operand_output_aliases={5: 0},
+      operand_output_aliases={0: 0},
       result_shapes=result_shapes,
+      backend_config={},
+      api_version=4,
   ).results
-  return out[:2]
 # ?potrf: Cholesky decomposition
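For reference, the layout tuples shared by both branches above describe batched column-major (Fortran-order) buffers, which is what the LAPACK kernels expect. A small illustration (not part of this change) for two batch dimensions:

# Minor-to-major layouts as constructed in orgqr_hlo above: the matrix
# dimensions vary fastest (column-major), the batch dimensions slowest.
num_bd = 2  # e.g. an `a` operand of shape (b0, b1, m, n)
matrix_layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
tau_layout = tuple(range(num_bd, -1, -1))  # for the (b0, b1, k) tau operand
print(matrix_layout)  # (2, 3, 1, 0)
print(tau_layout)     # (2, 1, 0)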