mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add custom rules for str_eqn_compact
PiperOrigin-RevId: 651911281
This commit is contained in:
parent
ea8de20d45
commit
f3c1cbc709
@ -3350,11 +3350,17 @@ def _compact_eqn_should_include(k: str, v: Any) -> bool:
|
||||
return False
|
||||
return True
|
||||
|
||||
def str_eqn_compact(primitive_name: str, params: dict) -> str:
|
||||
def str_eqn_compact(primitive: Primitive, params: dict[Any, Any]) -> str:
|
||||
"Compact equation to string conversion used in HLO metadata."
|
||||
if primitive in custom_str_eqn_compact_rules:
|
||||
return custom_str_eqn_compact_rules[primitive](primitive, params)
|
||||
primitive_name = primitive.name
|
||||
kvs = " ".join(f"{k}={v}" for k, v in params.items()
|
||||
if _compact_eqn_should_include(k, v))
|
||||
return f"{primitive_name}[{kvs}]" if len(kvs) > 0 else primitive_name
|
||||
custom_str_eqn_compact_rules: dict[
|
||||
Primitive, Callable[[Primitive, dict[Any, Any]], str]
|
||||
] = {}
|
||||
|
||||
def pp_jaxpr_skeleton(jaxpr, eqns_fn, context: JaxprPpContext,
|
||||
settings: JaxprPpSettings) -> pp.Doc:
|
||||
|
@ -411,7 +411,7 @@ def _source_info_to_location(
|
||||
ctx: ModuleContext, primitive: core.Primitive, params: dict[str, Any],
|
||||
source_info: source_info_util.SourceInfo) -> ir.Location:
|
||||
eqn_str = (f'{source_info.name_stack}/'
|
||||
f'{core.str_eqn_compact(primitive.name, params)}')
|
||||
f'{core.str_eqn_compact(primitive, params)}')
|
||||
if config.include_full_tracebacks_in_locations.value:
|
||||
if source_info.traceback is None:
|
||||
loc = ir.Location.unknown()
|
||||
|
@ -1002,6 +1002,17 @@ def _pallas_call_lowering(
|
||||
mlir.register_lowering(pallas_call_p, _pallas_call_lowering)
|
||||
|
||||
|
||||
def _pallas_custom_str_eqn_compact(
|
||||
prim: jax_core.Primitive, params: dict[Any, Any]
|
||||
) -> str:
|
||||
del prim, params
|
||||
# Hide most info from compact str representation
|
||||
return "pallas_call"
|
||||
jax_core.custom_str_eqn_compact_rules[pallas_call_p] = (
|
||||
_pallas_custom_str_eqn_compact
|
||||
)
|
||||
|
||||
|
||||
def pallas_call(
|
||||
f: Callable[..., None],
|
||||
out_shape: Any,
|
||||
|
@ -1259,7 +1259,7 @@ def _make_op_metadata(primitive: core.Primitive,
|
||||
source_info: source_info_util.SourceInfo,
|
||||
) -> xla_client.OpMetadata:
|
||||
eqn_str = (str(source_info.name_stack) + '/'
|
||||
+ core.str_eqn_compact(primitive.name, params))
|
||||
+ core.str_eqn_compact(primitive, params))
|
||||
frame = source_info_util.user_frame(source_info)
|
||||
return xla_client.OpMetadata(
|
||||
op_type=primitive.name,
|
||||
|
Loading…
x
Reference in New Issue
Block a user