pallas: improve indexing trace time

This commit is contained in:
Jake VanderPlas 2024-01-09 11:32:00 -08:00
parent e76f514b49
commit be8183d746

View File

@ -81,10 +81,15 @@ def unpack_ndindexer(indexer: NDIndexer) -> tuple[tuple[bool, ...],
return tuple(is_int_indexing), tuple(slice_indexers), tuple(int_indexers) # type: ignore
def _maybe_concretize(x: Any):
try:
return core.concrete_or_error(None, x)
except core.ConcretizationTypeError:
return None
# This is roughly the same logic as core.concrete_or_error, but we avoid
# calling that because constructing the ConcretizationTypeError can be
# expensive as the size of the tracing context (i.e. the jaxpr) grows.
if isinstance(x, core.Tracer):
if isinstance(x.aval, core.ConcreteArray):
return x.aval.val
else:
return None
return x
@tree_util.register_pytree_node_class
@dataclasses.dataclass