mirror of
https://github.com/ROCm/jax.git
synced 2025-04-25 01:16:05 +00:00
Fix custom_linear_solve batching rule in case of auxiliary arguments
Previously, batching in/out axis lists of the wrong length were passed to the batched-jaxpr builders for the `matvec` and `solve_t` jaxprs. This commit is a best-effort fix, derived from debugging the axis designations used when constructing the batched jaxprs for these functions.
This commit is contained in:
parent
2ccd785e16
commit
89039ecc9c
@ -387,7 +387,9 @@ def _linear_solve_batching_rule(spmd_axis_name, axis_size, axis_name, main_type,
|
||||
(matvec, vecmat, solve, solve_t) = jaxprs
|
||||
(matvec_bat, vecmat_bat, solve_bat, solve_t_bat) = params_bat
|
||||
|
||||
num_aux = len(solve.out_avals) - len(matvec.out_avals)
|
||||
# number of operator out avals is assumed to be the same for matvec/vecmat
|
||||
num_operator_out_avals = len(matvec.out_avals)
|
||||
num_aux = len(solve.out_avals) - num_operator_out_avals
|
||||
# Fixpoint computation of which parts of x and b are batched; we need to
|
||||
# ensure this is consistent between all four jaxprs
|
||||
b_bat = orig_b_bat
|
||||
@ -402,21 +404,23 @@ def _linear_solve_batching_rule(spmd_axis_name, axis_size, axis_name, main_type,
|
||||
x_bat_out = solve_x_bat
|
||||
else:
|
||||
vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr(
|
||||
vecmat, axis_size, vecmat_bat + b_bat, instantiate=x_bat,
|
||||
vecmat, axis_size, vecmat_bat + b_bat, instantiate=b_bat,
|
||||
axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
|
||||
# batch all aux data by default
|
||||
x_bat_out = _map(operator.or_, vecmat_x_bat + [True] * num_aux, solve_x_bat)
|
||||
# keep a slice of only the linear operator part of solve's avals
|
||||
x_bat_noaux = x_bat_out[:num_operator_out_avals]
|
||||
|
||||
# Apply matvec and solve_t -> new batched parts of b
|
||||
matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr(
|
||||
matvec, axis_size, matvec_bat + x_bat_out, instantiate=b_bat,
|
||||
matvec, axis_size, matvec_bat + x_bat_noaux, instantiate=b_bat,
|
||||
axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
|
||||
if solve_t is None:
|
||||
solve_t_jaxpr_batched = None
|
||||
b_bat_out = _map(operator.or_, matvec_b_bat, orig_b_bat)
|
||||
else:
|
||||
solve_t_jaxpr_batched, solve_t_b_aux_bat = batching.batch_jaxpr(
|
||||
solve_t, axis_size, solve_t_bat + x_bat_out, instantiate=b_bat,
|
||||
solve_t, axis_size, solve_t_bat + x_bat_noaux, instantiate=x_bat_out,
|
||||
axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
|
||||
assert len(solve_t_b_aux_bat) == len(orig_b_bat) + num_aux
|
||||
solve_t_b_bat, _ = split_list(solve_t_b_aux_bat, [len(orig_b_bat)])
|
||||
|
Loading…
x
Reference in New Issue
Block a user