Fix custom_linear_solve batching rule in case of auxiliary arguments

Previously, input/output batching axes of the wrong lengths were passed
into the batched jaxpr builders for the `matvec` and `solve_t` jaxprs.
This commit is a best-effort fix, arrived at by debugging the axis
designations used when constructing the batched jaxprs of these functions.
This commit is contained in:
Nicholas Junge 2023-03-06 16:57:31 +01:00
parent 2ccd785e16
commit 89039ecc9c

View File

@ -387,7 +387,9 @@ def _linear_solve_batching_rule(spmd_axis_name, axis_size, axis_name, main_type,
(matvec, vecmat, solve, solve_t) = jaxprs
(matvec_bat, vecmat_bat, solve_bat, solve_t_bat) = params_bat
num_aux = len(solve.out_avals) - len(matvec.out_avals)
# number of operator out avals is assumed to be the same for matvec/vecmat
num_operator_out_avals = len(matvec.out_avals)
num_aux = len(solve.out_avals) - num_operator_out_avals
# Fixpoint computation of which parts of x and b are batched; we need to
# ensure this is consistent between all four jaxprs
b_bat = orig_b_bat
@ -402,21 +404,23 @@ def _linear_solve_batching_rule(spmd_axis_name, axis_size, axis_name, main_type,
x_bat_out = solve_x_bat
else:
vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr(
vecmat, axis_size, vecmat_bat + b_bat, instantiate=x_bat,
vecmat, axis_size, vecmat_bat + b_bat, instantiate=b_bat,
axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
# batch all aux data by default
x_bat_out = _map(operator.or_, vecmat_x_bat + [True] * num_aux, solve_x_bat)
# keep a slice of only the linear operator part of solve's avals
x_bat_noaux = x_bat_out[:num_operator_out_avals]
# Apply matvec and solve_t -> new batched parts of b
matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr(
matvec, axis_size, matvec_bat + x_bat_out, instantiate=b_bat,
matvec, axis_size, matvec_bat + x_bat_noaux, instantiate=b_bat,
axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
if solve_t is None:
solve_t_jaxpr_batched = None
b_bat_out = _map(operator.or_, matvec_b_bat, orig_b_bat)
else:
solve_t_jaxpr_batched, solve_t_b_aux_bat = batching.batch_jaxpr(
solve_t, axis_size, solve_t_bat + x_bat_out, instantiate=b_bat,
solve_t, axis_size, solve_t_bat + x_bat_noaux, instantiate=x_bat_out,
axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
assert len(solve_t_b_aux_bat) == len(orig_b_bat) + num_aux
solve_t_b_bat, _ = split_list(solve_t_b_aux_bat, [len(orig_b_bat)])