[pallas] Added a flag disabling verbose error reporting

PiperOrigin-RevId: 691463398
This commit is contained in:
Sergei Lebedev 2024-10-30 10:12:47 -07:00 committed by jax authors
parent da994d3552
commit 6283eab2ff
4 changed files with 20 additions and 0 deletions

View File

@ -50,6 +50,7 @@ from jax._src.lib.mlir.dialects import memref
from jax._src.lib.mlir.dialects import scf
from jax._src.lib.mlir.dialects import vector
from jax._src.pallas import core as pallas_core
from jax._src.pallas import pallas_call
from jax._src.pallas import primitives
from jax._src.pallas import utils as pallas_utils
from jax._src.pallas.mosaic import core as tpu_core
@ -837,6 +838,8 @@ def jaxpr_subcomp(
except LoweringException:
raise # We only add the extra info to the innermost exception.
except Exception as e:
if not pallas_call._verbose_errors_enabled():
raise
msg = (f"{type(e).__name__}: {e}\n" +
"Additional diagnostics: \n" +
f"Failing jaxpr equation: {eqn}\n")

View File

@ -40,6 +40,7 @@ from jax._src.lib.mlir.dialects import memref as memref_dialect
from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect
from jax._src.lib.mlir.dialects import scf as scf_dialect
from jax._src.pallas import core as pallas_core
from jax._src.pallas import pallas_call
from jax._src.pallas import primitives
from jax._src.pallas import utils as pallas_utils
from jax._src.pallas.mosaic_gpu import core as gpu_core
@ -885,6 +886,8 @@ def lower_jaxpr_to_mosaic_gpu(
except LoweringError:
raise # We only add the extra info to the innermost exception.
except Exception as e:
if not pallas_call._verbose_errors_enabled():
raise
inval_types = map(lambda t: getattr(t, "type", None), invals)
raise LoweringError(
f"Exception while lowering eqn:\n {eqn}\nWith context:\n "

View File

@ -1448,6 +1448,17 @@ _PALLAS_USE_MOSAIC_GPU = config.bool_flag(
" dialect, instead of Trition IR."
),
)
_PALLAS_VERBOSE_ERRORS = config.bool_flag(
"jax_pallas_verbose_errors",
default=config.bool_env("JAX_PALLAS_VERBOSE_ERRORS", True),
help=(
"If True, print verbose error messages for Pallas kernels."
),
)
def _verbose_errors_enabled() -> bool:
return _PALLAS_VERBOSE_ERRORS.value
def _unsupported_lowering_error(platform: str) -> Exception:

View File

@ -46,6 +46,7 @@ from jax._src.lib.mlir.dialects import math as math_dialect
from jax._src.lib.mlir.dialects import scf as scf_dialect
from jax._src.lib.triton import dialect as tt_dialect
from jax._src.pallas import core as pallas_core
from jax._src.pallas import pallas_call
from jax._src.pallas import primitives
from jax._src.pallas import utils as pallas_utils
from jax._src.state import discharge
@ -390,6 +391,8 @@ def lower_jaxpr_to_triton_ir(
except LoweringError:
raise # We only add the extra info to the innermost exception.
except Exception as e:
if not pallas_call._verbose_errors_enabled():
raise
inval_types = map(lambda t: getattr(t, "type", None), invals)
raise LoweringError(
f"Exception while lowering eqn:\n {eqn}\nWith context:\n "