mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Copybara import of the project:
-- f79d1060cccf7c9a1c02d0bcab06c6ee0ef795a8 by George Necula <gcnecula@gmail.com>: [export] Fix A user reported an error when trying to export a function that has a "lower" attribute (to impersonate a jitted function) but does not have a "__name__" attribute. The solution is to use the default name "<unnamed function>". While I was at it I have added a `util.fun_name` to get the name of a Callable, and I use it in several places. COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21572 from gnecula:exp_fix_name f79d1060cccf7c9a1c02d0bcab06c6ee0ef795a8 PiperOrigin-RevId: 639236990
This commit is contained in:
parent
432159a9d3
commit
be1e40dc2e
@ -62,7 +62,7 @@ from jax._src.tree_util import (tree_flatten, tree_unflatten, all_leaves,
|
||||
from jax._src.util import (safe_map, safe_zip, HashableFunction, unzip2, unzip3,
|
||||
as_hashable_function, distributed_debug_log,
|
||||
tuple_insert, moveaxis, split_list, wrap_name,
|
||||
merge_lists, partition_list)
|
||||
merge_lists, partition_list, fun_name)
|
||||
|
||||
source_info_util.register_exclusion(__file__)
|
||||
traceback_util.register_exclusion(__file__)
|
||||
@ -577,7 +577,7 @@ def xmap(fun: Callable,
|
||||
in_axes_flat, args_flat)
|
||||
|
||||
params = dict(
|
||||
name=getattr(fun, '__name__', '<unnamed function>'),
|
||||
name=fun_name(fun),
|
||||
in_axes=tuple(in_axes_flat),
|
||||
out_axes_thunk=out_axes_thunk,
|
||||
donated_invars=donated_invars,
|
||||
|
@ -79,7 +79,7 @@ from jax._src.tree_util import (
|
||||
from jax._src.util import (
|
||||
HashableFunction, safe_map, safe_zip, wraps,
|
||||
distributed_debug_log, split_list, weakref_lru_cache,
|
||||
merge_lists, flatten, unflatten, subs_list)
|
||||
merge_lists, flatten, unflatten, subs_list, fun_name)
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
@ -339,7 +339,7 @@ def _cpp_pjit(jit_info: PjitInfo):
|
||||
|
||||
fun = jit_info.fun
|
||||
cpp_pjit_f = xc._xla.pjit(
|
||||
getattr(fun, "__name__", "<unnamed function>"),
|
||||
fun_name(fun),
|
||||
fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames,
|
||||
jit_info.donate_argnums, tree_util.dispatch_registry,
|
||||
pxla.shard_arg,
|
||||
@ -652,7 +652,7 @@ def _infer_params(jit_info, args, kwargs):
|
||||
out_layouts=out_layouts_flat,
|
||||
resource_env=resource_env,
|
||||
donated_invars=donated_invars,
|
||||
name=getattr(flat_fun, '__name__', '<unknown>'),
|
||||
name=fun_name(flat_fun),
|
||||
keep_unused=keep_unused,
|
||||
inline=inline,
|
||||
)
|
||||
|
@ -616,7 +616,7 @@ class Lowered(Stage):
|
||||
args_info, # PyTree of ArgInfo
|
||||
out_tree: tree_util.PyTreeDef,
|
||||
no_kwargs: bool = False,
|
||||
fun_name: str = "unknown",
|
||||
fun_name: str = "<unnamed function>",
|
||||
jaxpr: core.ClosedJaxpr | None = None):
|
||||
|
||||
self._lowering = lowering
|
||||
@ -634,7 +634,7 @@ class Lowered(Stage):
|
||||
donate_argnums: tuple[int, ...],
|
||||
out_tree: tree_util.PyTreeDef,
|
||||
no_kwargs: bool = False,
|
||||
fun_name: str = "unknown",
|
||||
fun_name: str = "<unnamed function>",
|
||||
jaxpr: core.ClosedJaxpr | None = None):
|
||||
"""Initialize from flat info (``in_avals`` etc.) and an input PyTreeDef.
|
||||
|
||||
|
@ -360,6 +360,9 @@ class WrapKwArgs:
|
||||
def wrap_name(name, transform_name):
|
||||
return transform_name + '(' + name + ')'
|
||||
|
||||
def fun_name(fun: Callable):
|
||||
return getattr(fun, "__name__", "<unnamed function>")
|
||||
|
||||
def canonicalize_axis(axis, num_dims) -> int:
|
||||
"""Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
|
||||
axis = operator.index(axis)
|
||||
@ -399,7 +402,7 @@ def wraps(
|
||||
"""
|
||||
def wrapper(fun: T) -> T:
|
||||
try:
|
||||
name = getattr(wrapped, "__name__", "<unnamed function>")
|
||||
name = fun_name(wrapped)
|
||||
doc = getattr(wrapped, "__doc__", "") or ""
|
||||
fun.__dict__.update(getattr(wrapped, "__dict__", {}))
|
||||
fun.__annotations__ = getattr(wrapped, "__annotations__", {})
|
||||
|
@ -405,7 +405,7 @@ def export(fun_jax: Callable,
|
||||
symbolic_scope = (d.scope, k_path)
|
||||
continue
|
||||
symbolic_scope[0]._check_same_scope(
|
||||
d, when=f"when exporting {getattr(wrapped_fun_jax, '__name__')}",
|
||||
d, when=f"when exporting {util.fun_name(wrapped_fun_jax)}",
|
||||
self_descr=f"current (from {_shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ",
|
||||
other_descr=_shape_poly.args_kwargs_path_to_str(k_path))
|
||||
|
||||
|
@ -538,6 +538,25 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
r"and found for 'w' \(args\[1\]\) scope .*", re.DOTALL)):
|
||||
get_exported(f)(x_poly_spec, y_poly_spec)
|
||||
|
||||
def test_poly_export_callable_with_no_name(self):
|
||||
# This was reported by a user
|
||||
class MyCallable:
|
||||
def __call__(self, x):
|
||||
return jnp.sin(x)
|
||||
|
||||
# This makes it look like a jitted-function
|
||||
def lower(self, x,
|
||||
_experimental_lowering_parameters=None):
|
||||
return jax.jit(self.__call__).lower(
|
||||
x,
|
||||
_experimental_lowering_parameters=_experimental_lowering_parameters)
|
||||
|
||||
a, = export.symbolic_shape("a,")
|
||||
# No error
|
||||
_ = get_exported(MyCallable())(
|
||||
jax.ShapeDtypeStruct((a, a), dtype=np.float32)
|
||||
)
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
dict(v=v)
|
||||
|
Loading…
x
Reference in New Issue
Block a user