mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Fix stale reference to util.prod.
Work around pytype bug. It seems that the line from functools import cached_property causes pytype to give up on the entire module. Avoid the member import to fix the type inference. PiperOrigin-RevId: 513544106
This commit is contained in:
parent
a9421a806f
commit
a002643a4a
@ -15,6 +15,7 @@
|
||||
# On-device arrays.
|
||||
|
||||
from functools import partial, partialmethod
|
||||
import math
|
||||
import operator
|
||||
from typing import (Any, List, Optional, Union)
|
||||
import weakref
|
||||
@ -26,7 +27,6 @@ import jax
|
||||
from jax._src import core
|
||||
from jax._src import abstract_arrays
|
||||
from jax._src import profiler
|
||||
from jax._src import util
|
||||
from jax._src.config import config
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.typing import Array
|
||||
@ -163,7 +163,7 @@ class _DeviceArray(DeviceArray): # type: ignore
|
||||
|
||||
@property
|
||||
def size(self):
|
||||
return util.prod(self.aval.shape)
|
||||
return math.prod(self.aval.shape)
|
||||
|
||||
@property
|
||||
def ndim(self):
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
from functools import partial, cached_property
|
||||
from functools import partial
|
||||
import itertools as it
|
||||
import logging
|
||||
import operator
|
||||
@ -496,7 +496,7 @@ class HashableWrapper:
|
||||
def _original_func(f):
|
||||
if isinstance(f, property):
|
||||
return cast(property, f).fget
|
||||
elif isinstance(f, cached_property):
|
||||
elif isinstance(f, functools.cached_property):
|
||||
return f.func
|
||||
return f
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user