mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Use private names for args in api_util to avoid shadowing kwargs keys.
This is a quick fix for #25329. We probably shouldn't use kwargs in linear_util. We probably shouldn't use linear_util...
This commit is contained in:
parent
1ac6b762dd
commit
dd74394e63
@ -283,15 +283,15 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...],
|
||||
return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args
|
||||
|
||||
@lu.transformation2
|
||||
def _argnums_partial(f, dyn_argnums, fixed_args, *dyn_args, **kwargs):
|
||||
def _argnums_partial(_fun, _dyn_argnums, _fixed_args, *dyn_args, **kwargs):
|
||||
sentinel = object()
|
||||
args = [sentinel] * (len(fixed_args) + len(dyn_args))
|
||||
for i, arg in zip(dyn_argnums, dyn_args):
|
||||
args = [sentinel] * (len(_fixed_args) + len(dyn_args))
|
||||
for i, arg in zip(_dyn_argnums, dyn_args):
|
||||
args[i] = arg
|
||||
fixed_args_ = iter(fixed_args)
|
||||
fixed_args_ = iter(_fixed_args)
|
||||
args = [next(fixed_args_).val if x is sentinel else x for x in args]
|
||||
assert next(fixed_args_, sentinel) is sentinel
|
||||
return f(*args, **kwargs)
|
||||
return _fun(*args, **kwargs)
|
||||
|
||||
def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...],
|
||||
kwargs: dict[str, Any]):
|
||||
@ -315,9 +315,9 @@ def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...],
|
||||
return _argnames_partial(f, WrapKwArgs(fixed_kwargs)), dyn_kwargs
|
||||
|
||||
@lu.transformation2
|
||||
def _argnames_partial(f, fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
|
||||
kwargs = dict({k: v.val for k, v in fixed_kwargs.val.items()}, **dyn_kwargs)
|
||||
return f(*args, **kwargs)
|
||||
def _argnames_partial(_fun, _fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
|
||||
kwargs = dict({k: v.val for k, v in _fixed_kwargs.val.items()}, **dyn_kwargs)
|
||||
return _fun(*args, **kwargs)
|
||||
|
||||
|
||||
@lru_cache(maxsize=4096)
|
||||
@ -438,9 +438,9 @@ def flat_out_axes(
|
||||
return f, HashableFunction(out_axes, closure=(tuple(leaves), treedef))
|
||||
|
||||
@lu.transformation_with_aux2
|
||||
def _flat_out_axes(f, store, leaves, treedef, *args, **kwargs):
|
||||
ans = f(*args, **kwargs)
|
||||
spec = tree_unflatten(treedef, leaves)
|
||||
def _flat_out_axes(_fun, _store, _leaves, _treedef, *args, **kwargs):
|
||||
ans = _fun(*args, **kwargs)
|
||||
spec = tree_unflatten(_treedef, _leaves)
|
||||
try:
|
||||
spec_flat = tuple(broadcast_prefix(spec, ans, is_leaf=lambda x: x is None))
|
||||
except ValueError:
|
||||
@ -451,7 +451,7 @@ def _flat_out_axes(f, store, leaves, treedef, *args, **kwargs):
|
||||
"that the `out_axes` argument to `pmap` is a pytree prefix of the "
|
||||
"pmapped function's output.")
|
||||
raise ValueError(msg) from None
|
||||
store.store(spec_flat)
|
||||
_store.store(spec_flat)
|
||||
return ans
|
||||
|
||||
def check_callable(fun):
|
||||
@ -687,10 +687,10 @@ def _arg_names(fn_signature, args, kwargs, static_argnums, static_argnames,
|
||||
for path, l in generate_key_paths(x) if l is not static)
|
||||
|
||||
@lu.transformation_with_aux2
|
||||
def result_paths(f, store, *args, **kwargs):
|
||||
def result_paths(_fun, _store, *args, **kwargs):
|
||||
"linear_util transform to get output pytree paths of pre-flattened function."
|
||||
ans = f(*args, **kwargs)
|
||||
store.store([keystr(path) for path, _ in generate_key_paths(ans)])
|
||||
ans = _fun(*args, **kwargs)
|
||||
_store.store([keystr(path) for path, _ in generate_key_paths(ans)])
|
||||
return ans
|
||||
|
||||
def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: TracingDebugInfo | None,
|
||||
|
Loading…
x
Reference in New Issue
Block a user