improve partial_eval_jaxpr_custom

* add caching via weakref_lru_cache
* add inst_in argument (needed for fixedpoints for loop primitives, in
  follow-up PR), update callers not to over-instantiate inputs (previously I
  had used a convention where call primitives would just stage out eqns with
  all inputs instantiated, for expediene)
* add ensure_out_unknowns and ensure_out_inst arguments, analogues of
  `instantiate` on e.g. partial_eval_jaxpr, jvp_jaxpr, etc (also neede for
 fixpoints of loop primitives)
* better dce in remat_partial_eval (e.g. prune unused residuals)
This commit is contained in:
Matthew Johnson 2022-05-06 21:55:04 -07:00
parent 705e241409
commit 7e241b682d
5 changed files with 145 additions and 79 deletions

View File

@ -13,6 +13,7 @@
# limitations under the License.
from functools import partial
import operator as op
from typing import Callable, Optional, List, Tuple
import types
@ -29,7 +30,7 @@ from jax._src import source_info_util
from jax._src.api_util import flatten_fun
from jax._src.traceback_util import api_boundary
from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map,
safe_zip)
safe_zip, merge_lists)
source_info_util.register_exclusion(__file__)
@ -302,43 +303,46 @@ ad.primitive_jvps[remat_p] = remat_jvp
def remat_partial_eval(trace, *tracers, jaxpr, **params):
assert not jaxpr.constvars
policy = params['policy'] or (lambda *_, **__: False)
# unzip into jaxpr_known and jaxpr_unknown
policy = params['policy'] or nothing_saveable
in_unknowns = [not t.is_known() for t in tracers]
# TODO(mattjj): use cached version of pe.partial_eval_jaxpr_custom
jaxpr_known, jaxpr_unknown, out_unknowns, out_inst, _ = \
pe._partial_eval_jaxpr_custom(jaxpr, in_unknowns, policy)
jaxpr_known, in_used_known = pe.dce_jaxpr(jaxpr_known, [True] * len(jaxpr_known.outvars))
_, used_outs_unknown = partition_list(out_inst, out_unknowns)
jaxpr_unknown, in_used_unknown = pe.dce_jaxpr(jaxpr_unknown, used_outs_unknown)
jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res = \
pe.partial_eval_jaxpr_custom(
jaxpr, in_unknowns, [True] * len(in_unknowns), False, False, policy)
# DCE jaxpr_staged, keeping only instantiated outputs which are unknown
_, out_inst_unknown = partition_list(out_inst, out_unknowns)
jaxpr_unknown, in_used_staged = pe.dce_jaxpr(jaxpr_staged, out_inst_unknown)
used_res, in_used_staged = split_list(in_used_staged, [num_res])
# DCE jaxpr_known, keeping all known outputs but discarding dce'd res
out_used_known = [True] * (len(out_unknowns) - sum(out_unknowns)) + used_res
jaxpr_known, in_used_known = pe.dce_jaxpr(jaxpr_known, out_used_known)
num_res = sum(used_res)
# compute known outputs and residuals (hoisted out of remat primitive)
_, in_consts_ = unzip2(t.pval for t in tracers if t.pval.is_known())
_, in_consts = partition_list(in_used_known, in_consts_)
out_consts = core.eval_jaxpr(jaxpr_known, (), *in_consts)
out_consts_ = iter(out_consts)
# form known outputs and collect residual tracers
out_known_tracers = [
pe.JaxprTracer(trace, pe.PartialVal.known(next(out_consts_)), None)
for uk in out_unknowns if not uk]
residuals = list(out_consts_)
out_knowns, residuals = split_list(out_consts, [len(out_consts)-num_res])
# set up unknown outputs with a recipe to call remat
res_tracers = map(trace.new_instantiated_const, residuals)
in_jaxpr_tracers = [*res_tracers, *map(trace.instantiate_const, tracers)]
_, in_jaxpr_tracers = partition_list(in_used_unknown, in_jaxpr_tracers)
_, tracers_staged = partition_list(in_used_staged, tracers)
in_jaxpr_tracers = res_tracers + map(trace.instantiate_const, tracers_staged)
out_jaxpr_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(x.aval), None)
for x in jaxpr_unknown.outvars]
new_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True)
recipe = pe.new_eqn_recipe(in_jaxpr_tracers, out_jaxpr_tracers, remat_p,
new_params, jaxpr_unknown.effects, source_info_util.current())
new_params, jaxpr_unknown.effects,
source_info_util.current())
for t in out_jaxpr_tracers: t.recipe = recipe
# zip together known and unknown outputs
return pe._zip_knowns(out_known_tracers, out_jaxpr_tracers, out_unknowns)
return merge_lists(out_unknowns, out_knowns, out_jaxpr_tracers)
pe.custom_partial_eval_rules[remat_p] = remat_partial_eval
def remat_partial_eval_custom_params_updater(_, __, ___, ____, params_known, params_staged):
def remat_partial_eval_custom_params_updater(*args):
*_, params_known, params_staged = args
return params_known, dict(params_staged, differentiated=True)
pe.partial_eval_jaxpr_custom_rules[remat_p] = \
partial(pe.call_partial_eval_custom_rule, 'jaxpr',

View File

@ -1057,12 +1057,12 @@ def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
pe.DynamicJaxprTrace.process_xmap = _dynamic_jaxpr_process_xmap # type: ignore
def _xmap_partial_eval_custom_params_updater(
unks_in: Sequence[bool],
unks_in: Sequence[bool], inst_in: Sequence[bool],
kept_outs_known: Sequence[bool], kept_outs_staged: Sequence[bool],
num_res: int, params_known: dict, params_staged: dict
) -> Tuple[dict, dict]:
assert params_known['spmd_in_axes'] is None and params_known['spmd_out_axes'] is None
assert params_staged['spmd_in_axes'] is None and params_staged['spmd_out_axes'] is None
assert params_known['spmd_in_axes'] is None is params_known['spmd_out_axes']
assert params_staged['spmd_in_axes'] is None is params_staged['spmd_out_axes']
# pruned inputs to jaxpr_known according to unks_in
donated_invars_known, _ = pe.partition_list(unks_in, params_known['donated_invars'])
@ -1085,13 +1085,15 @@ def _xmap_partial_eval_custom_params_updater(
assert len(new_params_known['in_axes']) == len(params_known['call_jaxpr'].invars)
assert len(new_params_known['out_axes']) == len(params_known['call_jaxpr'].outvars)
# added num_res new inputs to jaxpr_staged
donated_invars_staged = (*(False for _ in range(num_res)), *params_staged['donated_invars'])
# added num_res new inputs to jaxpr_staged, and pruning according to inst_in
_, donated_invars_staged = pe.partition_list(inst_in, params_staged['donated_invars'])
donated_invars_staged = [False] * num_res + donated_invars_staged
_, in_axes_staged = pe.partition_list(inst_in, params_staged['in_axes'])
in_axes_staged = [*residual_axes, *in_axes_staged]
_, out_axes_staged = pe.partition_list(kept_outs_staged, params_staged['out_axes'])
new_params_staged = dict(params_staged,
in_axes=(*residual_axes, *params_staged['in_axes']),
new_params_staged = dict(params_staged, in_axes=tuple(in_axes_staged),
out_axes=tuple(out_axes_staged),
donated_invars=donated_invars_staged)
donated_invars=tuple(donated_invars_staged))
assert len(new_params_staged['in_axes']) == len(params_staged['call_jaxpr'].invars)
assert len(new_params_staged['out_axes']) == len(params_staged['call_jaxpr'].outvars)
return new_params_known, new_params_staged

View File

@ -938,8 +938,9 @@ def _remat_partial_eval(trace, _, f, tracers, params):
if params['policy']:
# unzip into jaxpr_known and jaxpr_unknown
jaxpr_known, jaxpr_unknown, out_unknowns, out_inst, _ = _partial_eval_jaxpr_custom(
jaxpr, in_unknowns, params['policy'])
jaxpr_known, jaxpr_unknown, out_unknowns, out_inst, _ = \
partial_eval_jaxpr_custom(jaxpr, in_unknowns, [True] * len(in_unknowns),
False, False, params['policy'])
jaxpr_known, in_used_known = dce_jaxpr(jaxpr_known, [True] * len(jaxpr_known.outvars))
_, used_outs_unknown = partition_list(out_inst, out_unknowns)
jaxpr_unknown, in_used_unknown = dce_jaxpr(jaxpr_unknown, used_outs_unknown)
@ -1008,20 +1009,31 @@ def _remat_partial_eval(trace, _, f, tracers, params):
return merge_lists(out_unknowns, known_outputs, unknown_outputs)
call_partial_eval_rules[remat_call_p] = _remat_partial_eval
def _partition_knowns(pvals, unknowns: Sequence[bool]):
return ([e for e, unknown in zip(pvals, unknowns) if not unknown],
[e for e, unknown in zip(pvals, unknowns) if unknown])
def partial_eval_jaxpr_custom(
jaxpr: Jaxpr,
in_unknowns: Sequence[bool],
in_inst: Sequence[bool],
ensure_out_unknowns: Union[bool, Sequence[bool]],
ensure_out_inst: Union[bool, Sequence[bool]],
saveable: Callable[..., bool],
) -> Tuple[Jaxpr, Jaxpr, List[bool], List[bool], int]:
if type(ensure_out_unknowns) is bool:
ensure_out_unknowns = (ensure_out_unknowns,) * len(jaxpr.outvars)
if type(ensure_out_inst) is bool:
ensure_out_inst = (ensure_out_inst,) * len(jaxpr.outvars)
return _partial_eval_jaxpr_custom_cached(
jaxpr, tuple(in_unknowns), tuple(in_inst), tuple(ensure_out_unknowns),
tuple(ensure_out_inst), saveable)
def _zip_knowns(known_list, unknown_list, which_unknown: Sequence[bool]):
assert len(known_list) + len(unknown_list) == len(which_unknown)
known_iter, unknown_iter = iter(known_list), iter(unknown_list)
return [next(unknown_iter) if uk else next(known_iter) for uk in which_unknown]
def _partial_eval_jaxpr_custom(
jaxpr: Jaxpr, in_unknowns: Sequence[bool], saveable: Callable[..., bool],
) -> Tuple[Jaxpr, Jaxpr, Sequence[bool], Sequence[bool], int]:
if jaxpr.constvars: raise NotImplementedError # TODO(mattjj)
@weakref_lru_cache
def _partial_eval_jaxpr_custom_cached(
jaxpr: Jaxpr,
in_unknowns: Tuple[bool, ...],
in_inst: Tuple[bool, ...],
ensure_out_unknowns: Tuple[bool, ...],
ensure_out_inst: Tuple[bool, ...],
saveable: Callable[..., bool],
) -> Tuple[Jaxpr, Jaxpr, List[bool], List[bool], int]:
env: Dict[Var, Tuple[bool, bool]] = {}
residuals: OrderedSet[Var] = OrderedSet()
@ -1040,7 +1052,7 @@ def _partial_eval_jaxpr_custom(
return x
known_eqns, staged_eqns = [], []
map(write, in_unknowns, [True] * len(in_unknowns), jaxpr.invars)
map(write, in_unknowns, in_inst, jaxpr.invars)
for eqn in jaxpr.eqns:
unks_in, inst_in = unzip2(map(read, eqn.invars))
rule = partial_eval_jaxpr_custom_rules.get(eqn.primitive)
@ -1064,17 +1076,23 @@ def _partial_eval_jaxpr_custom(
out_unknowns, out_inst = unzip2(map(read, jaxpr.outvars))
assert all(type(v) is Var for v in residuals), residuals
for x, inst, ensure_inst in zip(jaxpr.outvars, out_inst, ensure_out_inst):
if ensure_inst: ensure_instantiated(inst, x)
out_unknowns = map(op.or_, out_unknowns, ensure_out_unknowns)
out_inst = map(op.or_, out_inst, ensure_out_inst)
ins_known, _ = partition_list(in_unknowns, jaxpr.invars)
outs_known, _ = partition_list(out_unknowns, jaxpr.outvars)
known_effects = core.join_effects(*(eqn.effects for eqn in known_eqns))
jaxpr_known = Jaxpr((), ins_known, [*outs_known, *residuals], known_eqns,
known_effects)
jaxpr_known = Jaxpr(jaxpr.constvars, ins_known, [*outs_known, *residuals],
known_eqns, known_effects)
config.jax_enable_checks and core.check_jaxpr(jaxpr_known)
_, ins_staged = partition_list(in_inst, jaxpr.invars)
_, outs_staged = partition_list(out_inst, jaxpr.outvars)
staged_effects = core.join_effects(*(eqn.effects for eqn in staged_eqns))
jaxpr_staged = Jaxpr((), [*residuals, *jaxpr.invars], outs_staged,
staged_eqns, staged_effects)
jaxpr_staged = Jaxpr(jaxpr.constvars, [*residuals, *ins_staged],
outs_staged, staged_eqns, staged_effects)
config.jax_enable_checks and core.check_jaxpr(jaxpr_staged)
return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, len(residuals)
@ -1105,7 +1123,8 @@ def partial_eval_jaxpr_custom_rule_not_implemented(
ParamsUpdater = Callable[[Sequence[bool], Sequence[bool], Sequence[bool],
int, dict, dict], Tuple[dict, dict]]
Sequence[bool], int, dict, dict],
Tuple[dict, dict]]
def call_partial_eval_custom_rule(
jaxpr_param_name: str, params_updater: ParamsUpdater,
@ -1114,21 +1133,21 @@ def call_partial_eval_custom_rule(
) -> Tuple[JaxprEqn, JaxprEqn, Sequence[bool], Sequence[bool], List[Var]]:
jaxpr = eqn.params[jaxpr_param_name]
jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \
_partial_eval_jaxpr_custom(jaxpr, unks_in, saveable)
partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable)
ins_known, _ = partition_list(unks_in, eqn.invars)
out_binders_known, _ = partition_list(unks_out, eqn.outvars)
_, ins_staged = partition_list(inst_in, eqn.invars)
_, out_binders_staged = partition_list(inst_out, eqn.outvars)
kept_outs_staged = inst_out
newvar = core.gensym([jaxpr_known, jaxpr_staged])
residuals = [newvar(v.aval) for v in jaxpr_staged.invars[:num_res]]
params_known = {**eqn.params, jaxpr_param_name: jaxpr_known}
params_staged = {**eqn.params, jaxpr_param_name: jaxpr_staged}
params_known, params_staged = params_updater(
unks_in, map(op.not_, unks_out), kept_outs_staged, num_res, params_known,
unks_in, inst_in, map(op.not_, unks_out), inst_out, num_res, params_known,
params_staged)
eqn_known = new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
eqn.primitive, params_known, jaxpr_known.effects, eqn.source_info)
eqn_staged = new_jaxpr_eqn([*residuals, *eqn.invars], out_binders_staged,
eqn_staged = new_jaxpr_eqn([*residuals, *ins_staged], out_binders_staged,
eqn.primitive, params_staged,
jaxpr_staged.effects, eqn.source_info)
assert len(eqn_staged.invars) == len(jaxpr_staged.invars)
@ -1137,13 +1156,14 @@ def call_partial_eval_custom_rule(
return eqn_known, eqn_staged, unks_out, inst_out, new_inst + residuals
partial_eval_jaxpr_custom_rules[core.call_p] = \
partial(call_partial_eval_custom_rule, 'call_jaxpr',
lambda _, __, ___, ____, x, y: (x, y))
lambda _, __, ___, ____, _____, x, y: (x, y))
partial_eval_jaxpr_custom_rules[core.named_call_p] = \
partial(call_partial_eval_custom_rule, 'call_jaxpr',
lambda _, __, ___, ____, x, y: (x, y))
lambda _, __, ___, ____, _____, x, y: (x, y))
partial_eval_jaxpr_custom_rules[remat_call_p] = \
partial(call_partial_eval_custom_rule, 'call_jaxpr',
lambda _, __, ___, ____, p1, p2: (p1, dict(p2, differentiated=True)))
lambda _, __, ___, ____, _____, p1, p2:
(p1, dict(p2, differentiated=True)))
def _jaxpr_forwarding(jaxpr: Jaxpr) -> List[Optional[int]]:

View File

@ -433,16 +433,17 @@ ad.primitive_transposes[xla_call_p] = partial(ad.call_transpose, xla_call_p)
def _xla_call_partial_eval_custom_params_updater(
unks_in: Sequence[bool],
unks_in: Sequence[bool], inst_in: Sequence[bool],
kept_outs_known: Sequence[bool], kept_outs_staged: Sequence[bool],
num_res: int, params_known: dict, params_staged: dict
) -> Tuple[dict, dict]:
# pruned inputs to jaxpr_known according to unks_in, so prune donated_invars
donated_invars_known, _ = partition_list(unks_in, params_known['donated_invars'])
new_params_known = dict(params_known, donated_invars=tuple(donated_invars_known))
donated_known, _ = partition_list(unks_in, params_known['donated_invars'])
new_params_known = dict(params_known, donated_invars=tuple(donated_known))
# added num_res new inputs to jaxpr_staged, so extend donated_invars
donated_invars_staged = [*([False] * num_res), *params_staged['donated_invars']]
new_params_staged = dict(params_staged, donated_invars=tuple(donated_invars_staged))
_, donated_staged_ = partition_list(inst_in, params_staged['donated_invars'])
donated_staged = [False] * num_res + donated_staged_
new_params_staged = dict(params_staged, donated_invars=tuple(donated_staged))
return new_params_known, new_params_staged
pe.partial_eval_jaxpr_custom_rules[xla_call_p] = \
partial(pe.call_partial_eval_custom_rule, 'call_jaxpr',

View File

@ -4007,17 +4007,20 @@ class RematTest(jtu.JaxTestCase):
self.assertNotIn('conditional', c.as_hlo_text())
@parameterized.named_parameters(
{"testcase_name": f"_{policy_name}", "policy": policy,
"in_jaxpr2": in_jaxpr2, "not_in_jaxpr2": not_in_jaxpr2}
{"testcase_name": f"_{policy_name}_{remat_name}", "remat": remat,
"policy": policy, "in_jaxpr2": in_jaxpr2, "not_in_jaxpr2": not_in_jaxpr2}
for remat_name, remat in [
('old_remat', api.remat),
('new_remat', new_checkpoint),
]
for policy_name, policy, in_jaxpr2, not_in_jaxpr2 in [
('save_anything', lambda *_, **__: True, [], [' sin ', ' cos ']),
('save_nothing', lambda *_, **__: False, [' sin ', ' cos '], []),
('save_sin', lambda p, *_, **__: str(p) == 'sin', [' cos '], [' sin ']),
])
def test_remat_custom_policy(self, policy, in_jaxpr2, not_in_jaxpr2):
def test_remat_custom_policy(self, remat, policy, in_jaxpr2, not_in_jaxpr2):
for square in [lambda x: x * x, api.jit(lambda x: x * x)]:
f = api.remat(lambda x: jnp.sin(square(jnp.sin(x))),
policy=policy)
f = remat(lambda x: jnp.sin(square(jnp.sin(x))), policy=policy)
y, f_lin = api.linearize(f, 1.)
ydot = f_lin(2.)
jaxpr_text = str(f_lin.func.args[0])
@ -4031,18 +4034,30 @@ class RematTest(jtu.JaxTestCase):
self.assertAllClose(ydot, ydot_expected)
jtu.check_grads(f, (3.,), order=2, modes=['fwd', 'rev'])
def test_remat_custom_policy_save_cos(self):
@parameterized.named_parameters(
{"testcase_name": f"_{remat_name}", "remat": remat}
for remat_name, remat in [
('old_remat', api.remat),
('new_remat', new_checkpoint),
])
def test_remat_custom_policy_save_cos(self, remat):
save_cos = lambda prim, *_, **__: str(prim) == 'cos'
f = api.remat(lambda x: jnp.sin(jnp.sin(x)), # different function
policy=save_cos)
f = remat(lambda x: jnp.sin(jnp.sin(x)), # different function
policy=save_cos)
_, f_lin = api.linearize(f, 1.)
jaxpr_text = str(f_lin.func.args[0])
self.assertNotIn(' sin ', jaxpr_text)
self.assertNotIn(' cos ', jaxpr_text)
jtu.check_grads(f, (3.,), order=2, modes=['fwd', 'rev'])
def test_remat_checkpoint_dots(self):
@partial(api.remat, policy=jax.checkpoint_policies.checkpoint_dots)
@parameterized.named_parameters(
{"testcase_name": f"_{remat_name}", "remat": remat}
for remat_name, remat in [
('old_remat', api.remat),
('new_remat', new_checkpoint),
])
def test_remat_checkpoint_dots(self, remat):
@partial(remat, policy=jax.checkpoint_policies.checkpoint_dots)
def f(x):
x = jnp.dot(x, x, precision=lax.Precision.HIGHEST)
x = jnp.sin(x)
@ -4058,8 +4073,14 @@ class RematTest(jtu.JaxTestCase):
self.assertEqual(jaxpr_text.count(' dot_'), 6)
jtu.check_grads(f, (jnp.ones((2, 2)),), order=2, modes=['fwd', 'rev'])
def test_remat_checkpoint_dots_with_no_batch_dims(self):
@partial(api.remat, policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims)
@parameterized.named_parameters(
{"testcase_name": f"_{remat_name}", "remat": remat}
for remat_name, remat in [
('old_remat', api.remat),
('new_remat', new_checkpoint),
])
def test_remat_checkpoint_dots_with_no_batch_dims(self, remat):
@partial(remat, policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims)
def f(x):
x = jnp.einsum('ij,jk->ik', x, x, precision=lax.Precision.HIGHEST)
x = jnp.sin(x)
@ -4075,8 +4096,14 @@ class RematTest(jtu.JaxTestCase):
self.assertEqual(jaxpr_text.count(' dot_general'), 6)
jtu.check_grads(f, (jnp.ones((2, 2)),), order=2, modes=['fwd', 'rev'])
def test_remat_checkpoint_dots_with_no_batch_dims2(self):
@partial(api.remat, policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims)
@parameterized.named_parameters(
{"testcase_name": f"_{remat_name}", "remat": remat}
for remat_name, remat in [
('old_remat', api.remat),
('new_remat', new_checkpoint),
])
def test_remat_checkpoint_dots_with_no_batch_dims2(self, remat):
@partial(remat, policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims)
def f(x):
x = jnp.einsum('nij,njk->nik', x, x, precision=lax.Precision.HIGHEST)
x = jnp.sin(x)
@ -4092,9 +4119,15 @@ class RematTest(jtu.JaxTestCase):
self.assertEqual(jaxpr_text.count(' dot_general'), 9)
jtu.check_grads(f, (jnp.ones((3, 2, 2)),), order=2, modes=['fwd', 'rev'])
def test_remat_checkpoint_dots_jit(self):
@parameterized.named_parameters(
{"testcase_name": f"_{remat_name}", "remat": remat}
for remat_name, remat in [
('old_remat', api.remat),
('new_remat', new_checkpoint),
])
def test_remat_checkpoint_dots_jit(self, remat):
@api.jit
@partial(api.remat, policy=jax.checkpoint_policies.checkpoint_dots)
@partial(remat, policy=jax.checkpoint_policies.checkpoint_dots)
def f(x):
x = jnp.dot(x, x, precision=lax.Precision.HIGHEST)
x = jnp.sin(x * 1e-3)
@ -4196,11 +4229,17 @@ class RematTest(jtu.JaxTestCase):
return lax.scan(lambda x, _: (f(x), None), x, None, length=2)[0]
jtu.check_grads(g, (3.,), order=2, modes=['rev'])
def test_remat_dropvar_policy(self):
@parameterized.named_parameters(
{"testcase_name": f"_{remat_name}", "remat": remat}
for remat_name, remat in [
('old_remat', api.remat),
('new_remat', new_checkpoint),
])
def test_remat_dropvar_policy(self, remat):
def f(x):
return x, x
@partial(api.remat, policy=jax.checkpoint_policies.checkpoint_dots)
@partial(remat, policy=jax.checkpoint_policies.checkpoint_dots)
def g(x):
x = api.grad(lambda x: f(x)[0])(x)
return x