mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[pallas] Added a flag disabling verbose error reporting
PiperOrigin-RevId: 691463398
This commit is contained in:
parent
da994d3552
commit
6283eab2ff
@ -50,6 +50,7 @@ from jax._src.lib.mlir.dialects import memref
|
||||
from jax._src.lib.mlir.dialects import scf
|
||||
from jax._src.lib.mlir.dialects import vector
|
||||
from jax._src.pallas import core as pallas_core
|
||||
from jax._src.pallas import pallas_call
|
||||
from jax._src.pallas import primitives
|
||||
from jax._src.pallas import utils as pallas_utils
|
||||
from jax._src.pallas.mosaic import core as tpu_core
|
||||
@ -837,6 +838,8 @@ def jaxpr_subcomp(
|
||||
except LoweringException:
|
||||
raise # We only add the extra info to the innermost exception.
|
||||
except Exception as e:
|
||||
if not pallas_call._verbose_errors_enabled():
|
||||
raise
|
||||
msg = (f"{type(e).__name__}: {e}\n" +
|
||||
"Additional diagnostics: \n" +
|
||||
f"Failing jaxpr equation: {eqn}\n")
|
||||
|
@ -40,6 +40,7 @@ from jax._src.lib.mlir.dialects import memref as memref_dialect
|
||||
from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect
|
||||
from jax._src.lib.mlir.dialects import scf as scf_dialect
|
||||
from jax._src.pallas import core as pallas_core
|
||||
from jax._src.pallas import pallas_call
|
||||
from jax._src.pallas import primitives
|
||||
from jax._src.pallas import utils as pallas_utils
|
||||
from jax._src.pallas.mosaic_gpu import core as gpu_core
|
||||
@ -885,6 +886,8 @@ def lower_jaxpr_to_mosaic_gpu(
|
||||
except LoweringError:
|
||||
raise # We only add the extra info to the innermost exception.
|
||||
except Exception as e:
|
||||
if not pallas_call._verbose_errors_enabled():
|
||||
raise
|
||||
inval_types = map(lambda t: getattr(t, "type", None), invals)
|
||||
raise LoweringError(
|
||||
f"Exception while lowering eqn:\n {eqn}\nWith context:\n "
|
||||
|
@ -1448,6 +1448,17 @@ _PALLAS_USE_MOSAIC_GPU = config.bool_flag(
|
||||
" dialect, instead of Trition IR."
|
||||
),
|
||||
)
|
||||
_PALLAS_VERBOSE_ERRORS = config.bool_flag(
|
||||
"jax_pallas_verbose_errors",
|
||||
default=config.bool_env("JAX_PALLAS_VERBOSE_ERRORS", True),
|
||||
help=(
|
||||
"If True, print verbose error messages for Pallas kernels."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _verbose_errors_enabled() -> bool:
|
||||
return _PALLAS_VERBOSE_ERRORS.value
|
||||
|
||||
|
||||
def _unsupported_lowering_error(platform: str) -> Exception:
|
||||
|
@ -46,6 +46,7 @@ from jax._src.lib.mlir.dialects import math as math_dialect
|
||||
from jax._src.lib.mlir.dialects import scf as scf_dialect
|
||||
from jax._src.lib.triton import dialect as tt_dialect
|
||||
from jax._src.pallas import core as pallas_core
|
||||
from jax._src.pallas import pallas_call
|
||||
from jax._src.pallas import primitives
|
||||
from jax._src.pallas import utils as pallas_utils
|
||||
from jax._src.state import discharge
|
||||
@ -390,6 +391,8 @@ def lower_jaxpr_to_triton_ir(
|
||||
except LoweringError:
|
||||
raise # We only add the extra info to the innermost exception.
|
||||
except Exception as e:
|
||||
if not pallas_call._verbose_errors_enabled():
|
||||
raise
|
||||
inval_types = map(lambda t: getattr(t, "type", None), invals)
|
||||
raise LoweringError(
|
||||
f"Exception while lowering eqn:\n {eqn}\nWith context:\n "
|
||||
|
Loading…
x
Reference in New Issue
Block a user