mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
in partial_eval_custom rule for pjit, cache ClosedJaxpr creation
Anywhere we call the ClosedJaxpr constructor, we had better be under a cache — we should audit the code to verify this. Never trust comments, especially when blame says mattjj wrote them.

Co-authored-by: Yash Katariya <yashkatariya@google.com>
This commit is contained in:
parent
14fe47c5b7
commit
e31018b9c5
@ -846,6 +846,37 @@ def bench_make_array_from_callback_fully_replicated_sharding(state):
|
||||
while state:
|
||||
jax.make_array_from_callback(shape, s, np_arr.__getitem__)
|
||||
|
||||
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMillisecond)
def benchmark_lorentz63_cache_hits(state):
  """Benchmark tracing a vmapped, rematerialized Lorenz-63 training step.

  Each benchmark iteration calls `jax.make_jaxpr` on a 100-step unrolled
  simulation wrapped in `jax.value_and_grad`, which stresses tracing-time
  caches (the unrolled loop re-traces the same jitted step repeatedly).
  """
  @jax.jit
  def lorentz63(state, dt=0.01, sigma=10, beta=8/3, rho=28):
    # One explicit-Euler step of the Lorenz-63 ODE system.
    x, y, z = state
    dx = sigma * (y - x)
    dy = (rho - z) * x - y
    dz = x * y - beta * z
    return jnp.array([x + dx * dt, y + dy * dt, z + dz * dt])

  def training_step(initial_conditions, steps=1, unroll=False):
    def forward_sim(x0):
      # Rolled form: a single fori_loop carrying the system state.
      if not unroll:
        return jax.lax.fori_loop(0, steps, lambda _, s: lorentz63(s), x0)
      # Unrolled form: one traced call per step.
      cur = x0
      for _ in range(steps):
        cur = lorentz63(cur)
      return cur

    def loss(x0):
      # vmap over the batch of initial conditions; remat the rollout so the
      # backward pass re-traces the forward simulation.
      trajectories = jax.vmap(jax.remat(forward_sim))(x0)
      return jnp.square(trajectories).sum()

    return jax.value_and_grad(loss)(initial_conditions)

  batch = jnp.ones((8, 3))
  while state:
    jax.make_jaxpr(lambda ics: training_step(ics, 100, unroll=True))(batch)
|
||||
|
||||
|
||||
if __name__ == "__main__":
  # Run all benchmarks registered above via Google Benchmark's CLI driver.
  google_benchmark.main()
|
||||
|
@ -1399,28 +1399,12 @@ def closed_call_partial_eval_custom_rule(
|
||||
eqn: JaxprEqn, *, res_aval: ResAvalUpdater = _default_res_aval_updater,
|
||||
) -> tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], list[Var]]:
|
||||
# TODO(sharadmv,mattjj): dedup this rule with call_partial_eval_custom_rule.
|
||||
closed_jaxpr = eqn.params[jaxpr_param_name]
|
||||
jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res_val, num_res_ref = \
|
||||
partial_eval_jaxpr_stateful(closed_jaxpr.jaxpr, unks_in, inst_in,
|
||||
False, False, saveable)
|
||||
dropvars = tuple(isinstance(v, DropVar) for v in eqn.outvars)
|
||||
jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res_ref, num_res_val, out_fwd = \
|
||||
_closed_jaxpr_partial_eval_custom_cached(
|
||||
eqn.params[jaxpr_param_name], (*unks_in,), (*inst_in,), dropvars, saveable)
|
||||
num_res = num_res_ref + num_res_val
|
||||
|
||||
# Compute which residual value outputs are also *undropped* primal outputs.
|
||||
num_out_primals = len(jaxpr_known_.outvars) - num_res_val
|
||||
out_vars, res_vars = split_list(jaxpr_known_.outvars, [num_out_primals])
|
||||
out_binders_known, _ = partition_list(unks_out, eqn.outvars)
|
||||
idx_map = {id(v): i for i, (v, b) in enumerate(zip(out_vars, out_binders_known))
|
||||
if type(b) is not DropVar}
|
||||
out_fwd = [idx_map.get(id(v)) for v in res_vars]
|
||||
|
||||
# Prune jaxpr_known_ outputs by removing forwards.
|
||||
jaxpr_known_ = prune_jaxpr_outputs(
|
||||
jaxpr_known_, [True] * num_out_primals + [f is None for f in out_fwd])
|
||||
|
||||
# Forming these fresh ClosedJaxprs defeats caching, but caller handles caching
|
||||
jaxpr_known = core.ClosedJaxpr(jaxpr_known_, closed_jaxpr.consts)
|
||||
jaxpr_staged = core.ClosedJaxpr(jaxpr_staged_, closed_jaxpr.consts)
|
||||
|
||||
ins_known, _ = partition_list(unks_in, eqn.invars)
|
||||
_, ins_staged = partition_list(inst_in, eqn.invars)
|
||||
_, out_binders_staged = partition_list(inst_out, eqn.outvars)
|
||||
@ -1452,6 +1436,33 @@ def closed_call_partial_eval_custom_rule(
|
||||
new_vars = [*new_inst, *res_val_vars, *res_ref_binders]
|
||||
return eqn_known, eqn_staged, unks_out, inst_out, new_vars
|
||||
|
||||
@weakref_lru_cache
def _closed_jaxpr_partial_eval_custom_cached(
    jaxpr: ClosedJaxpr, unks_in: tuple[bool, ...], inst_in: tuple[bool, ...],
    dropvars: tuple[bool, ...], saveable: Callable
) -> tuple[ClosedJaxpr, ClosedJaxpr, Sequence[bool], Sequence[bool],
           int, int, Sequence[int | None]]:
  """Cached core of the closed-call partial-eval-custom rule.

  Splits `jaxpr` into known and staged-out ClosedJaxprs. Constructing the
  result ClosedJaxprs under this weakref-LRU cache means repeated partial
  evaluation of the same ClosedJaxpr (with the same unknown/instantiate
  masks) returns the *same* ClosedJaxpr objects, so downstream id-keyed
  caches keep hitting.

  Args:
    jaxpr: the closed jaxpr to partially evaluate (the weakref cache key).
    unks_in: per-input mask of which inputs are unknown.
    inst_in: per-input mask of which inputs must be instantiated.
    dropvars: per-output mask: True where the caller's eqn outvar is a
      DropVar (passed instead of the outvars themselves so the key is
      hashable and cacheable).
    saveable: policy callable deciding which intermediates may be saved.
      NOTE(review): as part of the cache key, this relies on callers passing
      a stable callable object — confirm at call sites.

  Returns:
    (jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res_ref,
     num_res_val, out_fwd), where out_fwd[i] is the known-output index that
    residual value i is forwarded from, or None if it is a genuine extra
    residual output.
  """
  jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res_val, num_res_ref = \
      partial_eval_jaxpr_stateful(jaxpr.jaxpr, unks_in, inst_in,
                                  False, False, saveable)

  # Compute which residual value outputs are also *undropped* primal outputs.
  # Known outputs are laid out as [primal outputs..., value residuals...].
  num_out_primals = len(jaxpr_known_.outvars) - num_res_val
  out_vars, res_vars = split_list(jaxpr_known_.outvars, [num_out_primals])
  # Select the dropvar flags corresponding to the *known* outputs.
  out_dropvars_known, _ = partition_list(unks_out, dropvars)
  # Map each undropped primal output var (by identity) to its position.
  idx_map = {id(v): i for i, (v, b) in enumerate(zip(out_vars, out_dropvars_known))
             if not b}
  out_fwd = [idx_map.get(id(v)) for v in res_vars]

  # Prune jaxpr_known_ outputs by removing forwards.
  jaxpr_known_ = prune_jaxpr_outputs(
      jaxpr_known_, [True] * num_out_primals + [f is None for f in out_fwd])

  # Build the result ClosedJaxprs *inside* the cache so callers get
  # identical objects on cache hits (see commit message).
  jaxpr_known = core.ClosedJaxpr(jaxpr_known_, jaxpr.consts)
  jaxpr_staged = core.ClosedJaxpr(jaxpr_staged_, jaxpr.consts)
  return jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res_ref, num_res_val, out_fwd
|
||||
|
||||
|
||||
# Register the partial-eval-custom rule for core.call_p. The final lambda is
# the params-updater: it ignores the first five positional arguments and
# returns the (known, staged) values unchanged — presumably because plain
# calls need no param rewriting; verify against call_partial_eval_custom_rule.
partial_eval_jaxpr_custom_rules[core.call_p] = \
    partial(call_partial_eval_custom_rule, 'call_jaxpr',
            lambda _, __, ___, ____, _____, x, y: (x, y))
|
||||
|
Loading…
x
Reference in New Issue
Block a user