mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
pallas: improve indexing trace time
This commit is contained in:
parent
e76f514b49
commit
be8183d746
@ -81,10 +81,15 @@ def unpack_ndindexer(indexer: NDIndexer) -> tuple[tuple[bool, ...],
|
||||
return tuple(is_int_indexing), tuple(slice_indexers), tuple(int_indexers) # type: ignore
|
||||
|
||||
def _maybe_concretize(x: Any):
|
||||
try:
|
||||
return core.concrete_or_error(None, x)
|
||||
except core.ConcretizationTypeError:
|
||||
return None
|
||||
# This is roughly the same logic as core.concrete_or_error, but we avoid
|
||||
# calling that because constructing the ConcretizationTypeError can be
|
||||
# expensive as the size of the tracing context (i.e. the jaxpr) grows.
|
||||
if isinstance(x, core.Tracer):
|
||||
if isinstance(x.aval, core.ConcreteArray):
|
||||
return x.aval.val
|
||||
else:
|
||||
return None
|
||||
return x
|
||||
|
||||
@tree_util.register_pytree_node_class
|
||||
@dataclasses.dataclass
|
||||
|
Loading…
x
Reference in New Issue
Block a user