Add custom rules for str_eqn_compact

PiperOrigin-RevId: 651911281
This commit is contained in:
Sharad Vikram 2024-07-12 16:02:48 -07:00 committed by jax authors
parent ea8de20d45
commit f3c1cbc709
4 changed files with 20 additions and 3 deletions

View File

@ -3350,11 +3350,17 @@ def _compact_eqn_should_include(k: str, v: Any) -> bool:
return False
return True
def str_eqn_compact(primitive_name: str, params: dict) -> str:
def str_eqn_compact(primitive: Primitive, params: dict[Any, Any]) -> str:
"Compact equation to string conversion used in HLO metadata."
if primitive in custom_str_eqn_compact_rules:
return custom_str_eqn_compact_rules[primitive](primitive, params)
primitive_name = primitive.name
kvs = " ".join(f"{k}={v}" for k, v in params.items()
if _compact_eqn_should_include(k, v))
return f"{primitive_name}[{kvs}]" if len(kvs) > 0 else primitive_name
custom_str_eqn_compact_rules: dict[
Primitive, Callable[[Primitive, dict[Any, Any]], str]
] = {}
def pp_jaxpr_skeleton(jaxpr, eqns_fn, context: JaxprPpContext,
settings: JaxprPpSettings) -> pp.Doc:

View File

@ -411,7 +411,7 @@ def _source_info_to_location(
ctx: ModuleContext, primitive: core.Primitive, params: dict[str, Any],
source_info: source_info_util.SourceInfo) -> ir.Location:
eqn_str = (f'{source_info.name_stack}/'
f'{core.str_eqn_compact(primitive.name, params)}')
f'{core.str_eqn_compact(primitive, params)}')
if config.include_full_tracebacks_in_locations.value:
if source_info.traceback is None:
loc = ir.Location.unknown()

View File

@ -1002,6 +1002,17 @@ def _pallas_call_lowering(
mlir.register_lowering(pallas_call_p, _pallas_call_lowering)
def _pallas_custom_str_eqn_compact(
prim: jax_core.Primitive, params: dict[Any, Any]
) -> str:
del prim, params
# Hide most info from compact str representation
return "pallas_call"
jax_core.custom_str_eqn_compact_rules[pallas_call_p] = (
_pallas_custom_str_eqn_compact
)
def pallas_call(
f: Callable[..., None],
out_shape: Any,

View File

@ -1259,7 +1259,7 @@ def _make_op_metadata(primitive: core.Primitive,
source_info: source_info_util.SourceInfo,
) -> xla_client.OpMetadata:
eqn_str = (str(source_info.name_stack) + '/'
+ core.str_eqn_compact(primitive.name, params))
+ core.str_eqn_compact(primitive, params))
frame = source_info_util.user_frame(source_info)
return xla_client.OpMetadata(
op_type=primitive.name,