Fix stale reference to util.prod.

Work around pytype bug. It seems that the line
from functools import cached_property
causes pytype to give up on the entire module. Avoid the member import to fix the type inference.

PiperOrigin-RevId: 513544106
This commit is contained in:
Peter Hawkins 2023-03-02 08:23:50 -08:00 committed by jax authors
parent a9421a806f
commit a002643a4a
2 changed files with 4 additions and 4 deletions

View File

@ -15,6 +15,7 @@
# On-device arrays.
from functools import partial, partialmethod
import math
import operator
from typing import (Any, List, Optional, Union)
import weakref
@ -26,7 +27,6 @@ import jax
from jax._src import core
from jax._src import abstract_arrays
from jax._src import profiler
from jax._src import util
from jax._src.config import config
from jax._src.lib import xla_client as xc
from jax._src.typing import Array
@ -163,7 +163,7 @@ class _DeviceArray(DeviceArray): # type: ignore
@property
def size(self):
return util.prod(self.aval.shape)
return math.prod(self.aval.shape)
@property
def ndim(self):

View File

@ -13,7 +13,7 @@
# limitations under the License.
import functools
from functools import partial, cached_property
from functools import partial
import itertools as it
import logging
import operator
@ -496,7 +496,7 @@ class HashableWrapper:
def _original_func(f):
if isinstance(f, property):
return cast(property, f).fget
elif isinstance(f, cached_property):
elif isinstance(f, functools.cached_property):
return f.func
return f