From be8183d7464cfbeded373349b18ef3df8679573c Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 9 Jan 2024 11:32:00 -0800 Subject: [PATCH] pallas: improve indexing trace time --- jax/_src/state/indexing.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/jax/_src/state/indexing.py b/jax/_src/state/indexing.py index 12c7623c2..210cd2e75 100644 --- a/jax/_src/state/indexing.py +++ b/jax/_src/state/indexing.py @@ -81,10 +81,15 @@ def unpack_ndindexer(indexer: NDIndexer) -> tuple[tuple[bool, ...], return tuple(is_int_indexing), tuple(slice_indexers), tuple(int_indexers) # type: ignore def _maybe_concretize(x: Any): - try: - return core.concrete_or_error(None, x) - except core.ConcretizationTypeError: - return None + # This is roughly the same logic as core.concrete_or_error, but we avoid + # calling that because constructing the ConcretizationTypeError can be + # expensive as the size of the tracing context (i.e. the jaxpr) grows. + if isinstance(x, core.Tracer): + if isinstance(x.aval, core.ConcreteArray): + return x.aval.val + else: + return None + return x @tree_util.register_pytree_node_class @dataclasses.dataclass