From f3c1cbc70961655fce541f45183bd7ad2c841abf Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Fri, 12 Jul 2024 16:02:48 -0700 Subject: [PATCH] Add custom rules for str_eqn_compact PiperOrigin-RevId: 651911281 --- jax/_src/core.py | 8 +++++++- jax/_src/interpreters/mlir.py | 2 +- jax/_src/pallas/pallas_call.py | 11 +++++++++++ jax/experimental/jax2tf/jax2tf.py | 2 +- 4 files changed, 20 insertions(+), 3 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index b48ecd2a3..052573781 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -3350,11 +3350,17 @@ def _compact_eqn_should_include(k: str, v: Any) -> bool: return False return True -def str_eqn_compact(primitive_name: str, params: dict) -> str: +def str_eqn_compact(primitive: Primitive, params: dict[Any, Any]) -> str: "Compact equation to string conversion used in HLO metadata." + if primitive in custom_str_eqn_compact_rules: + return custom_str_eqn_compact_rules[primitive](primitive, params) + primitive_name = primitive.name kvs = " ".join(f"{k}={v}" for k, v in params.items() if _compact_eqn_should_include(k, v)) return f"{primitive_name}[{kvs}]" if len(kvs) > 0 else primitive_name +custom_str_eqn_compact_rules: dict[ + Primitive, Callable[[Primitive, dict[Any, Any]], str] +] = {} def pp_jaxpr_skeleton(jaxpr, eqns_fn, context: JaxprPpContext, settings: JaxprPpSettings) -> pp.Doc: diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 73f81a472..f9cf76055 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -411,7 +411,7 @@ def _source_info_to_location( ctx: ModuleContext, primitive: core.Primitive, params: dict[str, Any], source_info: source_info_util.SourceInfo) -> ir.Location: eqn_str = (f'{source_info.name_stack}/' - f'{core.str_eqn_compact(primitive.name, params)}') + f'{core.str_eqn_compact(primitive, params)}') if config.include_full_tracebacks_in_locations.value: if source_info.traceback is None: loc = ir.Location.unknown() diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index db300ef32..df476c1eb 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -1002,6 +1002,17 @@ def _pallas_call_lowering( mlir.register_lowering(pallas_call_p, _pallas_call_lowering) +def _pallas_custom_str_eqn_compact( + prim: jax_core.Primitive, params: dict[Any, Any] +) -> str: + del prim, params + # Hide most info from compact str representation + return "pallas_call" +jax_core.custom_str_eqn_compact_rules[pallas_call_p] = ( + _pallas_custom_str_eqn_compact +) + + def pallas_call( f: Callable[..., None], out_shape: Any, diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index cfed7d7d0..bb64bdae1 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1259,7 +1259,7 @@ def _make_op_metadata(primitive: core.Primitive, source_info: source_info_util.SourceInfo, ) -> xla_client.OpMetadata: eqn_str = (str(source_info.name_stack) + '/' - + core.str_eqn_compact(primitive.name, params)) + + core.str_eqn_compact(primitive, params)) frame = source_info_util.user_frame(source_info) return xla_client.OpMetadata( op_type=primitive.name,