Mirror of https://github.com/ROCm/jax.git
Activate Householder Product to XLA's FFI
PiperOrigin-RevId: 670196460
commit 414eb90f5b
parent 8cb3596136
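This change switches the CPU lowering of the Householder product (the orgqr/ungqr family) from the legacy lapack_*orgqr / lapack_*ungqr custom-call targets to the XLA FFI targets lapack_*orgqr_ffi / lapack_*ungqr_ffi, while keeping the legacy path alive during the export forward-compatibility window. Below is a minimal sketch, not part of the commit, of how one might observe the new target from user code; which target actually appears depends on the installed jax/jaxlib versions and on forward-compatibility mode, so the expected string is an assumption.

# Sketch: inspect which LAPACK custom-call target the Householder product
# lowers to on CPU.
import jax
import jax.numpy as jnp
from jax.lax.linalg import geqrf, householder_product

def orthonormal_q(x):
  a, taus = geqrf(x)                    # QR factorization: packed reflectors + taus
  return householder_product(a, taus)   # multiply the elementary reflectors together

x = jnp.eye(4, dtype=jnp.float32)
hlo_text = jax.jit(orthonormal_q).lower(x).as_text()
# With the FFI path active, the CPU module is expected (assumption) to reference
# "lapack_sorgqr_ffi" rather than the legacy "lapack_sorgqr" target.
print("lapack_sorgqr_ffi" in hlo_text)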
@@ -931,6 +931,7 @@ def _check_lowering(lowering) -> None:
 _CPU_FFI_KERNELS = [
     "lapack_spotrf_ffi", "lapack_dpotrf_ffi", "lapack_cpotrf_ffi", "lapack_zpotrf_ffi",
     "lapack_sgeqrf_ffi", "lapack_dgeqrf_ffi", "lapack_cgeqrf_ffi", "lapack_zgeqrf_ffi",
+    "lapack_sorgqr_ffi", "lapack_dorgqr_ffi", "lapack_cungqr_ffi", "lapack_zungqr_ffi",
     "lapack_ssyevd_ffi", "lapack_dsyevd_ffi", "lapack_cheevd_ffi", "lapack_zheevd_ffi",
     "lapack_sgeev_ffi", "lapack_dgeev_ffi", "lapack_cgeev_ffi", "lapack_zgeev_ffi",
     "lapack_sgesdd_ffi", "lapack_dgesdd_ffi", "lapack_cgesdd_ffi", "lapack_zgesdd_ffi",
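Note: _CPU_FFI_KERNELS is the list of CPU LAPACK FFI kernels that JAX's export machinery accepts as stable custom-call targets; adding the four orgqr/ungqr FFI names lets exported modules that use the new targets pass the check performed around _check_lowering. The snippet below is a hypothetical, simplified sketch of that kind of allow-list check; check_custom_call_target and the trimmed kernel list are illustrative, not the real export code.

_CPU_FFI_KERNELS = [
    "lapack_sgeqrf_ffi", "lapack_dgeqrf_ffi", "lapack_cgeqrf_ffi", "lapack_zgeqrf_ffi",
    "lapack_sorgqr_ffi", "lapack_dorgqr_ffi", "lapack_cungqr_ffi", "lapack_zungqr_ffi",
]

def check_custom_call_target(target_name: str) -> None:
  # Reject serialized custom-call targets that carry no stability guarantee.
  if target_name not in _CPU_FFI_KERNELS:
    raise ValueError(f"custom call target {target_name!r} is not guaranteed "
                     "to be stable across jax/jaxlib releases")

check_custom_call_target("lapack_sorgqr_ffi")   # passes after this change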
@@ -357,7 +357,7 @@ data_2024_08_22 = {}
 data_2024_08_22['c128'] = dict(
     testdata_version=1,
     platform='cpu',
-    custom_call_targets=['lapack_zgeqrf_ffi', 'lapack_zungqr'],
+    custom_call_targets=['lapack_zgeqrf_ffi', 'lapack_zungqr_ffi'],
     serialized_date=datetime.date(2024, 8, 22),
     inputs=(),
     expected_outputs=(
@@ -479,7 +479,7 @@ module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_p
 data_2024_08_22['c64'] = dict(
     testdata_version=1,
     platform='cpu',
-    custom_call_targets=['lapack_cgeqrf_ffi', 'lapack_cungqr'],
+    custom_call_targets=['lapack_cgeqrf_ffi', 'lapack_cungqr_ffi'],
     serialized_date=datetime.date(2024, 8, 22),
     inputs=(),
     expected_outputs=(
@@ -595,7 +595,7 @@ module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_p
 data_2024_08_22['f32'] = dict(
     testdata_version=1,
     platform='cpu',
-    custom_call_targets=['lapack_sgeqrf_ffi', 'lapack_sorgqr'],
+    custom_call_targets=['lapack_sgeqrf_ffi', 'lapack_sorgqr_ffi'],
     serialized_date=datetime.date(2024, 8, 22),
     inputs=(),
     expected_outputs=(
@@ -703,7 +703,7 @@ module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_p
 data_2024_08_22['f64'] = dict(
     testdata_version=1,
     platform='cpu',
-    custom_call_targets=['lapack_dgeqrf_ffi', 'lapack_dorgqr'],
+    custom_call_targets=['lapack_dgeqrf_ffi', 'lapack_dorgqr_ffi'],
     serialized_date=datetime.date(2024, 8, 22),
     inputs=(),
     expected_outputs=(
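Note: the dicts above are saved golden artifacts for the export backward-compatibility tests; their custom_call_targets fields are updated so the recorded data names the FFI targets that the current lowering emits. A tiny hypothetical helper summarizing the renaming follows; LEGACY_TO_FFI and to_ffi_target are illustrative names, not part of JAX.

# Each legacy Householder-product target gains an "_ffi" suffix.
LEGACY_TO_FFI = {
    "lapack_sorgqr": "lapack_sorgqr_ffi",
    "lapack_dorgqr": "lapack_dorgqr_ffi",
    "lapack_cungqr": "lapack_cungqr_ffi",
    "lapack_zungqr": "lapack_zungqr_ffi",
}

def to_ffi_target(name: str) -> str:
  # The geqrf targets in these dicts are already *_ffi; only orgqr/ungqr change here.
  return LEGACY_TO_FFI.get(name, name)

assert to_ffi_target("lapack_zungqr") == "lapack_zungqr_ffi"
assert to_ffi_target("lapack_zgeqrf_ffi") == "lapack_zgeqrf_ffi"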
@@ -1759,11 +1759,21 @@ def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus, *,
           f"on GPU is not implemented; b/261671778; {a_aval.shape}")
     a, info_orgqr = orgqr_impl(a_aval.dtype, a, taus)
   else:
+    # TODO(b/344892332): Remove the conditional after the compatibility period
+    ctx_args = (
+        (ctx,) if platform == "cpu" and jaxlib_version >= (0, 4, 32) else ()
+    )
     a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, a_aval.shape)
     tau_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, taus_aval.shape)
-    a, info_orgqr = orgqr_impl(a_aval.dtype, a, taus,
-                               a_shape_vals=a_shape_vals,
-                               tau_shape_vals=tau_shape_vals)
+    a, *maybe_info_orgqr = orgqr_impl(*ctx_args, a_aval.dtype, a, taus,
+                                      a_shape_vals=a_shape_vals,
+                                      tau_shape_vals=tau_shape_vals)
+    if not ctx.is_forward_compat():
+      # Skip the info parameter verification for the FFI kernel.
+      return [a]
+    # TODO(b/344892332): This parameter will no longer be needed after
+    # the forward compatibility period
+    info_orgqr = maybe_info_orgqr[0]
   zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32)))
   ok = mlir.compare_hlo(info_orgqr, zeros, "EQ", "SIGNED")
   select_a_aval = ShapedArray(batch_dims + [1, 1], np.dtype(np.bool_))
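Note: the lowering now threads the lowering context into the CPU kernel builder only when the installed jaxlib is new enough to accept it, unpacks a variable number of results, and skips the info == 0 verification when the FFI kernel is in use (errors then surface from the kernel itself). Below is a standalone, hedged sketch of that dispatch pattern; ForwardCompatCtx, build_kernel_call, and lower_householder_product are stand-in names, not JAX APIs.

from typing import Callable, Sequence

class ForwardCompatCtx:
  def __init__(self, forward_compat: bool):
    self._forward_compat = forward_compat

  def is_forward_compat(self) -> bool:
    return self._forward_compat

def lower_householder_product(ctx: ForwardCompatCtx,
                              build_kernel_call: Callable[..., Sequence],
                              jaxlib_version: tuple,
                              a, taus):
  # Older jaxlib builders do not accept a ctx argument, so it is threaded in
  # conditionally, mirroring the ctx_args tuple in the diff.
  ctx_args = (ctx,) if jaxlib_version >= (0, 4, 32) else ()
  a_out, *maybe_info = build_kernel_call(*ctx_args, a, taus)
  if not ctx.is_forward_compat():
    # FFI kernel: there is no info output to verify.
    return [a_out]
  # Legacy kernel: the extra result is the LAPACK info status.
  info = maybe_info[0]
  return [a_out, info]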
@@ -281,10 +281,11 @@ def geqrf_hlo(
 # # ?orgqr: product of elementary Householder reflectors:
-def orgqr_hlo(dtype, a: ir.Value, tau, *,
+def orgqr_hlo(ctx, dtype, a: ir.Value, tau, *,
               a_shape_vals: tuple[DimensionSize, ...],
               tau_shape_vals: tuple[DimensionSize, ...]):
   _lapack.initialize()
+  fn_base = "un" if dtype == np.complex64 or dtype == np.complex128 else "or"
+  fn_base = prepare_lapack_call(fn_base=fn_base + "gqr", dtype=dtype)
   a_type = ir.RankedTensorType(a.type)
   dims = a_type.shape
   dims_vals = a_shape_vals
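Note: the two new fn_base lines build the LAPACK routine family name: real dtypes use the "or"-prefixed routines (orgqr) and complex dtypes the "un"-prefixed ones (ungqr), and prepare_lapack_call then adds the type character to form targets such as lapack_sorgqr or lapack_zungqr. Below is a hypothetical equivalent of that name construction; lapack_target_name and _TYPE_PREFIX are illustrative, assuming the helper simply prepends the LAPACK type character.

import numpy as np

_TYPE_PREFIX = {np.float32: "s", np.float64: "d", np.complex64: "c", np.complex128: "z"}

def lapack_target_name(dtype) -> str:
  base = "un" if dtype in (np.complex64, np.complex128) else "or"
  return f"lapack_{_TYPE_PREFIX[np.dtype(dtype).type]}{base}gqr"

assert lapack_target_name(np.float32) == "lapack_sorgqr"
assert lapack_target_name(np.complex128) == "lapack_zungqr"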
@@ -294,55 +295,75 @@ def orgqr_hlo(dtype, a: ir.Value, tau, *,
   assert n != ir.ShapedType.get_dynamic_size()
   batch_dims_vals = dims_vals[:-2]
   num_bd = len(batch_dims_vals)
-  batch_size_val = hlo_s32(1)
-  for b_v in batch_dims_vals:
-    batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))

   k = tau_shape_vals[-1]
   assert type(k) is int

-  if dtype == np.float32:
-    fn = "lapack_sorgqr"
-    lwork = _lapack.lapack_sorgqr_workspace(m, n, k)
-  elif dtype == np.float64:
-    fn = "lapack_dorgqr"
-    lwork = _lapack.lapack_dorgqr_workspace(m, n, k)
-  elif dtype == np.complex64:
-    fn = "lapack_cungqr"
-    lwork = _lapack.lapack_cungqr_workspace(m, n, k)
-  elif dtype == np.complex128:
-    fn = "lapack_zungqr"
-    lwork = _lapack.lapack_zungqr_workspace(m, n, k)
-  else:
-    raise NotImplementedError(f"Unsupported dtype {dtype}")
-
-  scalar_layout = []
   layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
   i32_type = ir.IntegerType.get_signless(32)
+
+  if ctx.is_forward_compat():
+    fn = fn_base
+    batch_size_val = hlo_s32(1)
+    for b_v in batch_dims_vals:
+      batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))
+
+    if dtype == np.float32:
+      lwork = _lapack.lapack_sorgqr_workspace(m, n, k)
+    elif dtype == np.float64:
+      lwork = _lapack.lapack_dorgqr_workspace(m, n, k)
+    elif dtype == np.complex64:
+      lwork = _lapack.lapack_cungqr_workspace(m, n, k)
+    elif dtype == np.complex128:
+      lwork = _lapack.lapack_zungqr_workspace(m, n, k)
+    else:
+      raise NotImplementedError(f"Unsupported dtype {dtype}")
+
+    scalar_layout = []
+    shape_type_pairs: Sequence[ShapeTypePair] = [
+        (a_shape_vals, a_type.element_type),
+        (batch_dims_vals, i32_type),
+        ([lwork], a_type.element_type),
+    ]
+    result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs)
+    return custom_call(
+        fn,
+        result_types=result_types,
+        operands=[batch_size_val, hlo_s32(m), hlo_s32(n), hlo_s32(k),
+                  hlo_s32(lwork), a, tau],
+        operand_layouts=[scalar_layout] * 5 + [
+            layout,
+            tuple(range(num_bd, -1, -1)),
+        ],
+        result_layouts=[
+            layout,
+            tuple(range(num_bd - 1, -1, -1)),
+            [0],
+        ],
+        operand_output_aliases={5: 0},
+        result_shapes=result_shapes,
+    ).results[:2]
+  fn = fn_base + "_ffi"
   shape_type_pairs: Sequence[ShapeTypePair] = [
       (a_shape_vals, a_type.element_type),
-      (batch_dims_vals, i32_type),
-      ([lwork], a_type.element_type),
   ]
   result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs)
-  out = custom_call(
+  return custom_call(
       fn,
       result_types=result_types,
-      operands=[batch_size_val, hlo_s32(m), hlo_s32(n), hlo_s32(k),
-                hlo_s32(lwork), a, tau],
-      operand_layouts=[scalar_layout] * 5 + [
+      operands=[
+        a, tau
+      ],
+      operand_layouts=[
           layout,
           tuple(range(num_bd, -1, -1)),
       ],
       result_layouts=[
           layout,
-          tuple(range(num_bd - 1, -1, -1)),
-          [0],
       ],
-      operand_output_aliases={5: 0},
+      operand_output_aliases={0: 0},
       result_shapes=result_shapes,
+      backend_config={},
+      api_version=4,
   ).results
-  return out[:2]


 # ?potrf: Cholesky decomposition
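Note: outside the forward-compatibility branch, the builder now emits a typed-FFI custom call: operands are reduced to a and tau, the single result aliases operand 0, attributes travel in backend_config, and api_version=4 selects the FFI calling convention, so the legacy size/workspace scalars and the info output disappear. Below is an illustrative plain-Python summary of the difference; orgqr_call_summary is a made-up helper and the keys mirror the custom_call arguments in the diff rather than forming a real API.

def orgqr_call_summary(use_ffi: bool, fn_base: str) -> dict:
  if not use_ffi:
    # Forward-compatibility branch: legacy kernel with explicit size/workspace
    # scalars, three results (matrix, info, workspace), input matrix at index 5.
    return {
        "target": fn_base,                       # e.g. "lapack_sorgqr"
        "operands": "batch, m, n, k, lwork scalars + a, tau",
        "results": "a, info, workspace",
        "operand_output_aliases": {5: 0},
    }
  # FFI branch: just a and tau as operands, a single aliased result, attributes
  # in backend_config, and the typed-FFI calling convention (api_version=4).
  return {
      "target": fn_base + "_ffi",                # e.g. "lapack_sorgqr_ffi"
      "operands": "a, tau",
      "results": "a only",
      "operand_output_aliases": {0: 0},
      "backend_config": {},
      "api_version": 4,
  }

print(orgqr_call_summary(use_ffi=True, fn_base="lapack_sorgqr")["target"])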