in partial_eval_custom rule for pjit, cache ClosedJaxpr creation

Anywhere we call the ClosedJaxpr constructor, we had better be under a cache.
We should audit the code...

Never trust comments, especially when blame says mattjj wrote them

Co-authored-by: Yash Katariya <yashkatariya@google.com>
Matthew Johnson 2023-12-21 15:31:28 -08:00
parent 14fe47c5b7
commit e31018b9c5
2 changed files with 62 additions and 20 deletions
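
The failure mode behind the message is easy to demonstrate: weakref_lru_cache keys on the identity of its weakly referenced first argument, so a freshly constructed ClosedJaxpr never hits entries made for a structurally identical one. A minimal sketch, not part of this commit, assuming only the public jax.make_jaxpr and the internal jax._src.util.weakref_lru_cache:

    import jax
    import jax.numpy as jnp
    from jax._src import core
    from jax._src.util import weakref_lru_cache

    @weakref_lru_cache
    def analyze(closed_jaxpr):
      # Stand-in for expensive partial-eval work keyed on the ClosedJaxpr.
      print("cache miss")
      return len(closed_jaxpr.jaxpr.eqns)

    cj = jax.make_jaxpr(jnp.sin)(1.0)  # make_jaxpr returns a ClosedJaxpr
    analyze(cj)   # prints "cache miss"
    analyze(cj)   # silent: same object, cache hit
    analyze(core.ClosedJaxpr(cj.jaxpr, cj.consts))  # "cache miss" again,
                                                    # despite identical structure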

@@ -846,6 +846,37 @@ def bench_make_array_from_callback_fully_replicated_sharding(state):
   while state:
     jax.make_array_from_callback(shape, s, np_arr.__getitem__)
 
+@google_benchmark.register
+@google_benchmark.option.unit(google_benchmark.kMillisecond)
+def benchmark_lorentz63_cache_hits(state):
+  @jax.jit
+  def lorentz63(state, dt=0.01, sigma=10, beta=8/3, rho=28):
+    x, y, z = state
+    x_t = sigma * (y - x)
+    y_t = (rho - z) * x - y
+    z_t = x * y - beta * z
+    return jnp.array([x + x_t * dt, y + y_t * dt, z + z_t * dt])
+
+  def training_step(initial_conditions, steps=1, unroll=False):
+    def forward_sim(x0):
+      if unroll:
+        x = x0
+        for _ in range(steps):
+          x = lorentz63(x)
+        return x
+      else:
+        return jax.lax.fori_loop(0, steps, lambda _, x: lorentz63(x), x0)
+
+    def loss(x0):
+      out = jax.vmap(jax.remat(forward_sim))(x0)
+      return jnp.square(out).sum()
+
+    return jax.value_and_grad(loss)(initial_conditions)
+
+  x = jnp.ones((8, 3))
+  while state:
+    jax.make_jaxpr(lambda x: training_step(x, 100, unroll=True))(x)
+
 
 if __name__ == "__main__":
   google_benchmark.main()
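
Why this benchmark stresses the fix: each of the 100 unrolled lorentz63 calls becomes a pjit equation, and partially evaluating the remat/grad trace should invoke the pjit rule changed below once per equation; since jit tracing should hand back the same ClosedJaxpr object for every call at the same avals, the new cache turns all but the first rule application into hits (hence "cache_hits" in the name). A self-contained timing sketch in the same spirit, with hypothetical step/unrolled/loss stand-ins rather than the benchmark's own definitions:

    import timeit
    import jax
    import jax.numpy as jnp

    @jax.jit
    def step(x):
      return jnp.sin(x)  # stand-in for the lorentz63 step above

    def unrolled(x):
      for _ in range(100):  # 100 pjit equations sharing one ClosedJaxpr
        x = step(x)
      return x

    def loss(x):
      return jnp.square(jax.vmap(jax.remat(unrolled))(x)).sum()

    x = jnp.ones((8, 3))
    # Tracing the grad exercises closed_call_partial_eval_custom_rule once
    # per pjit equation; with the cache, the expensive partial-eval work
    # should run only for the first.
    print(timeit.timeit(
        lambda: jax.make_jaxpr(jax.value_and_grad(loss))(x), number=5))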

@@ -1399,28 +1399,12 @@ def closed_call_partial_eval_custom_rule(
     eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater,
   ) -> tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], list[Var]]:
   # TODO(sharadmv,mattjj): dedup this rule with call_partial_eval_custom_rule.
-  closed_jaxpr = eqn.params[jaxpr_param_name]
-  jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res_val, num_res_ref = \
-      partial_eval_jaxpr_stateful(closed_jaxpr.jaxpr, unks_in, inst_in,
-                                  False, False, saveable)
+  dropvars = tuple(isinstance(v, DropVar) for v in eqn.outvars)
+  jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res_ref, num_res_val, out_fwd = \
+      _closed_jaxpr_partial_eval_custom_cached(
+          eqn.params[jaxpr_param_name], (*unks_in,), (*inst_in,), dropvars, saveable)
   num_res = num_res_ref + num_res_val
-  # Compute which residual value outputs are also *undropped* primal outputs.
-  num_out_primals = len(jaxpr_known_.outvars) - num_res_val
-  out_vars, res_vars = split_list(jaxpr_known_.outvars, [num_out_primals])
-  out_binders_known, _ = partition_list(unks_out, eqn.outvars)
-  idx_map = {id(v): i for i, (v, b) in enumerate(zip(out_vars, out_binders_known))
-             if type(b) is not DropVar}
-  out_fwd = [idx_map.get(id(v)) for v in res_vars]
-  # Prune jaxpr_known_ outputs by removing forwards.
-  jaxpr_known_ = prune_jaxpr_outputs(
-      jaxpr_known_, [True] * num_out_primals + [f is None for f in out_fwd])
-  # Forming these fresh ClosedJaxprs defeats caching, but caller handles caching
-  jaxpr_known = core.ClosedJaxpr(jaxpr_known_, closed_jaxpr.consts)
-  jaxpr_staged = core.ClosedJaxpr(jaxpr_staged_, closed_jaxpr.consts)
   ins_known, _ = partition_list(unks_in, eqn.invars)
   _, ins_staged = partition_list(inst_in, eqn.invars)
   _, out_binders_staged = partition_list(inst_out, eqn.outvars)
@@ -1452,6 +1436,33 @@
   new_vars = [*new_inst, *res_val_vars, *res_ref_binders]
   return eqn_known, eqn_staged, unks_out, inst_out, new_vars
 
+
+@weakref_lru_cache
+def _closed_jaxpr_partial_eval_custom_cached(
+    jaxpr: ClosedJaxpr, unks_in: tuple[bool, ...], inst_in: tuple[bool, ...],
+    dropvars: tuple[bool, ...], saveable: Callable
+  ) -> tuple[ClosedJaxpr, ClosedJaxpr, Sequence[bool], Sequence[bool],
+             int, int, Sequence[int | None]]:
+  jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res_val, num_res_ref = \
+      partial_eval_jaxpr_stateful(jaxpr.jaxpr, unks_in, inst_in,
+                                  False, False, saveable)
+  # Compute which residual value outputs are also *undropped* primal outputs.
+  num_out_primals = len(jaxpr_known_.outvars) - num_res_val
+  out_vars, res_vars = split_list(jaxpr_known_.outvars, [num_out_primals])
+  out_dropvars_known, _ = partition_list(unks_out, dropvars)
+  idx_map = {id(v): i for i, (v, b) in enumerate(zip(out_vars, out_dropvars_known))
+             if not b}
+  out_fwd = [idx_map.get(id(v)) for v in res_vars]
+  # Prune jaxpr_known_ outputs by removing forwards.
+  jaxpr_known_ = prune_jaxpr_outputs(
+      jaxpr_known_, [True] * num_out_primals + [f is None for f in out_fwd])
+  jaxpr_known = core.ClosedJaxpr(jaxpr_known_, jaxpr.consts)
+  jaxpr_staged = core.ClosedJaxpr(jaxpr_staged_, jaxpr.consts)
+  return (jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res_ref,
+          num_res_val, out_fwd)
+
+
 partial_eval_jaxpr_custom_rules[core.call_p] = \
     partial(call_partial_eval_custom_rule, 'call_jaxpr',
             lambda _, __, ___, ____, _____, x, y: (x, y))
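
A note on the call-site shape in the first hunk: everything after the weakly referenced first argument has to serve as an ordinary hashable cache key, which is presumably why the rule passes (*unks_in,) and (*inst_in,) rather than the input lists, and precomputes a dropvars bool tuple instead of passing eqn.outvars (whose Var binders are fresh on every trace and would never match). A toy illustration of that constraint, with a hypothetical Box standing in for a ClosedJaxpr:

    from jax._src.util import weakref_lru_cache

    class Box:
      pass  # hypothetical stand-in: weakref-able, hashed by identity

    @weakref_lru_cache
    def process(box, flags):
      print("computing")  # expensive work would go here
      return sum(flags)

    b = Box()
    process(b, (True, False))    # prints "computing"
    process(b, (True, False))    # cache hit: same object, equal tuple key
    # process(b, [True, False])  # would raise TypeError: lists are unhashable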