Copybara import of the project:

--
f79d1060cccf7c9a1c02d0bcab06c6ee0ef795a8 by George Necula <gcnecula@gmail.com>:

[export] Fix

A user reported an error when trying to export a function
that has a "lower" attribute (to impersonate a jitted function)
but does not have a "__name__" attribute.
The solution is to use the default name "<unnamed function>".

While I was at it I have added a `util.fun_name` to get
the name of a Callable, and I use it in several places.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21572 from gnecula:exp_fix_name f79d1060cccf7c9a1c02d0bcab06c6ee0ef795a8
PiperOrigin-RevId: 639236990
This commit is contained in:
George Necula 2024-05-31 20:38:16 -07:00 committed by jax authors
parent 432159a9d3
commit be1e40dc2e
6 changed files with 31 additions and 9 deletions

View File

@ -62,7 +62,7 @@ from jax._src.tree_util import (tree_flatten, tree_unflatten, all_leaves,
from jax._src.util import (safe_map, safe_zip, HashableFunction, unzip2, unzip3,
as_hashable_function, distributed_debug_log,
tuple_insert, moveaxis, split_list, wrap_name,
merge_lists, partition_list)
merge_lists, partition_list, fun_name)
source_info_util.register_exclusion(__file__)
traceback_util.register_exclusion(__file__)
@ -577,7 +577,7 @@ def xmap(fun: Callable,
in_axes_flat, args_flat)
params = dict(
name=getattr(fun, '__name__', '<unnamed function>'),
name=fun_name(fun),
in_axes=tuple(in_axes_flat),
out_axes_thunk=out_axes_thunk,
donated_invars=donated_invars,

View File

@ -79,7 +79,7 @@ from jax._src.tree_util import (
from jax._src.util import (
HashableFunction, safe_map, safe_zip, wraps,
distributed_debug_log, split_list, weakref_lru_cache,
merge_lists, flatten, unflatten, subs_list)
merge_lists, flatten, unflatten, subs_list, fun_name)
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
@ -339,7 +339,7 @@ def _cpp_pjit(jit_info: PjitInfo):
fun = jit_info.fun
cpp_pjit_f = xc._xla.pjit(
getattr(fun, "__name__", "<unnamed function>"),
fun_name(fun),
fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames,
jit_info.donate_argnums, tree_util.dispatch_registry,
pxla.shard_arg,
@ -652,7 +652,7 @@ def _infer_params(jit_info, args, kwargs):
out_layouts=out_layouts_flat,
resource_env=resource_env,
donated_invars=donated_invars,
name=getattr(flat_fun, '__name__', '<unknown>'),
name=fun_name(flat_fun),
keep_unused=keep_unused,
inline=inline,
)

View File

@ -616,7 +616,7 @@ class Lowered(Stage):
args_info, # PyTree of ArgInfo
out_tree: tree_util.PyTreeDef,
no_kwargs: bool = False,
fun_name: str = "unknown",
fun_name: str = "<unnamed function>",
jaxpr: core.ClosedJaxpr | None = None):
self._lowering = lowering
@ -634,7 +634,7 @@ class Lowered(Stage):
donate_argnums: tuple[int, ...],
out_tree: tree_util.PyTreeDef,
no_kwargs: bool = False,
fun_name: str = "unknown",
fun_name: str = "<unnamed function>",
jaxpr: core.ClosedJaxpr | None = None):
"""Initialize from flat info (``in_avals`` etc.) and an input PyTreeDef.

View File

@ -360,6 +360,9 @@ class WrapKwArgs:
def wrap_name(name, transform_name):
return transform_name + '(' + name + ')'
def fun_name(fun: Callable):
return getattr(fun, "__name__", "<unnamed function>")
def canonicalize_axis(axis, num_dims) -> int:
"""Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
axis = operator.index(axis)
@ -399,7 +402,7 @@ def wraps(
"""
def wrapper(fun: T) -> T:
try:
name = getattr(wrapped, "__name__", "<unnamed function>")
name = fun_name(wrapped)
doc = getattr(wrapped, "__doc__", "") or ""
fun.__dict__.update(getattr(wrapped, "__dict__", {}))
fun.__annotations__ = getattr(wrapped, "__annotations__", {})

View File

@ -405,7 +405,7 @@ def export(fun_jax: Callable,
symbolic_scope = (d.scope, k_path)
continue
symbolic_scope[0]._check_same_scope(
d, when=f"when exporting {getattr(wrapped_fun_jax, '__name__')}",
d, when=f"when exporting {util.fun_name(wrapped_fun_jax)}",
self_descr=f"current (from {_shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ",
other_descr=_shape_poly.args_kwargs_path_to_str(k_path))

View File

@ -538,6 +538,25 @@ class JaxExportTest(jtu.JaxTestCase):
r"and found for 'w' \(args\[1\]\) scope .*", re.DOTALL)):
get_exported(f)(x_poly_spec, y_poly_spec)
def test_poly_export_callable_with_no_name(self):
# This was reported by a user
class MyCallable:
def __call__(self, x):
return jnp.sin(x)
# This makes it look like a jitted-function
def lower(self, x,
_experimental_lowering_parameters=None):
return jax.jit(self.__call__).lower(
x,
_experimental_lowering_parameters=_experimental_lowering_parameters)
a, = export.symbolic_shape("a,")
# No error
_ = get_exported(MyCallable())(
jax.ShapeDtypeStruct((a, a), dtype=np.float32)
)
@jtu.parameterized_filterable(
kwargs=[
dict(v=v)