mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
jnp.eye: handle larger-than int32 offsets
This commit is contained in:
parent
8bac6d7877
commit
4dd6334265
@ -84,6 +84,10 @@ T = TypeVar("T")
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
def _clip_int_to_valid_range(val: int, dtype) -> int:
|
||||
info = np.iinfo(dtype)
|
||||
return builtins.max(info.min, builtins.min(int(val), info.max))
|
||||
|
||||
def _validate_shapes(shapes: Sequence[Shape]):
|
||||
def _check_static_shape(shape: Shape):
|
||||
checked = canonicalize_shape(shape)
|
||||
@ -1242,7 +1246,7 @@ def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int) -> Array:
|
||||
|
||||
def _eye(dtype: DTypeLike, shape: Shape, offset: int) -> Array:
|
||||
"""Like numpy.eye, create a 2D array with ones on a diagonal."""
|
||||
offset = int(offset)
|
||||
offset = _clip_int_to_valid_range(offset, np.int32)
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
bool_eye = eq(add(broadcasted_iota(np.int32, shape, 0), np.int32(offset)),
|
||||
broadcasted_iota(np.int32, shape, 1))
|
||||
@ -1262,7 +1266,7 @@ def _delta(dtype: DTypeLike, shape: Shape, axes: Sequence[int]) -> Array:
|
||||
|
||||
def _tri(dtype: DTypeLike, shape: Shape, offset: int) -> Array:
|
||||
"""Like numpy.tri, create a 2D array with ones below a diagonal."""
|
||||
offset = int(offset)
|
||||
offset = _clip_int_to_valid_range(offset, np.int32)
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
bool_tri = ge(add(broadcasted_iota(np.int32, shape, 0), np.int32(offset)),
|
||||
broadcasted_iota(np.int32, shape, 1))
|
||||
|
@ -2025,7 +2025,20 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
dtype=default_dtypes,
|
||||
n=[0, 4],
|
||||
m=[None, 0, 1, 3, 4],
|
||||
k=list(range(-4, 4)),
|
||||
k=[*range(-4, 4), -2**100, 2**100],
|
||||
)
|
||||
def testEye(self, n, m, k, dtype):
|
||||
np_fun = lambda: np.eye(n, M=m, k=k, dtype=dtype)
|
||||
jnp_fun = lambda: jnp.eye(n, M=m, k=k, dtype=dtype)
|
||||
args_maker = lambda: []
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype=default_dtypes,
|
||||
n=[0, 4],
|
||||
m=[None, 0, 1, 3, 4],
|
||||
k=range(-4, 4),
|
||||
)
|
||||
def testTri(self, m, n, k, dtype):
|
||||
np_fun = lambda: np.tri(n, M=m, k=k, dtype=dtype)
|
||||
|
Loading…
x
Reference in New Issue
Block a user