jnp.eye: handle larger-than int32 offsets

This commit is contained in:
Jake VanderPlas 2023-11-09 10:23:20 -08:00
parent 8bac6d7877
commit 4dd6334265
2 changed files with 20 additions and 3 deletions

View File

@ -84,6 +84,10 @@ T = TypeVar("T")
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
def _clip_int_to_valid_range(val: int, dtype) -> int:
info = np.iinfo(dtype)
return builtins.max(info.min, builtins.min(int(val), info.max))
def _validate_shapes(shapes: Sequence[Shape]):
def _check_static_shape(shape: Shape):
checked = canonicalize_shape(shape)
@ -1242,7 +1246,7 @@ def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int) -> Array:
def _eye(dtype: DTypeLike, shape: Shape, offset: int) -> Array:
"""Like numpy.eye, create a 2D array with ones on a diagonal."""
offset = int(offset)
offset = _clip_int_to_valid_range(offset, np.int32)
dtype = dtypes.canonicalize_dtype(dtype)
bool_eye = eq(add(broadcasted_iota(np.int32, shape, 0), np.int32(offset)),
broadcasted_iota(np.int32, shape, 1))
@ -1262,7 +1266,7 @@ def _delta(dtype: DTypeLike, shape: Shape, axes: Sequence[int]) -> Array:
def _tri(dtype: DTypeLike, shape: Shape, offset: int) -> Array:
"""Like numpy.tri, create a 2D array with ones below a diagonal."""
offset = int(offset)
offset = _clip_int_to_valid_range(offset, np.int32)
dtype = dtypes.canonicalize_dtype(dtype)
bool_tri = ge(add(broadcasted_iota(np.int32, shape, 0), np.int32(offset)),
broadcasted_iota(np.int32, shape, 1))

View File

@ -2025,7 +2025,20 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
dtype=default_dtypes,
n=[0, 4],
m=[None, 0, 1, 3, 4],
k=list(range(-4, 4)),
k=[*range(-4, 4), -2**100, 2**100],
)
def testEye(self, n, m, k, dtype):
np_fun = lambda: np.eye(n, M=m, k=k, dtype=dtype)
jnp_fun = lambda: jnp.eye(n, M=m, k=k, dtype=dtype)
args_maker = lambda: []
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
dtype=default_dtypes,
n=[0, 4],
m=[None, 0, 1, 3, 4],
k=range(-4, 4),
)
def testTri(self, m, n, k, dtype):
np_fun = lambda: np.tri(n, M=m, k=k, dtype=dtype)