Merge pull request #22779 from gnecula:tril

PiperOrigin-RevId: 658271885
This commit is contained in:
jax authors 2024-07-31 22:47:25 -07:00
commit 6c083d78e6
5 changed files with 43 additions and 10 deletions

View File

@ -2191,20 +2191,24 @@ def min_dim(d1: DimSize, d2: DimSize) -> DimSize:
d1_is_constant = is_constant_dim(d1)
if d1_is_constant and is_constant_dim(d2):
return min(d1, d2)
d1 = concrete_dim_or_error(d1, "argument `d1` of `core.min_dim`")
d2 = concrete_dim_or_error(d2, "argument `d2` of `core.min_dim`")
if d1_is_constant:
return d2.rmin(d1) # type: ignore[union-attr]
return d2.rmin(d1)
else:
return d1.min(d2) # type: ignore[union-attr]
return d1.min(d2)
def max_dim(d1: DimSize, d2: DimSize) -> DimSize:
"""Like max(d1, d2) but for both constant and symbolic dimensions."""
d1_is_constant = is_constant_dim(d1)
if d1_is_constant and is_constant_dim(d2):
return max(d1, d2)
d1 = concrete_dim_or_error(d1, "argument `d1` of `core.max_dim`")
d2 = concrete_dim_or_error(d2, "argument `d2` of `core.max_dim`")
if d1_is_constant:
return d2.rmax(d1) # type: ignore[union-attr]
return d2.rmax(d1)
else:
return d1.max(d2) # type: ignore[union-attr]
return d1.max(d2)
def dimension_as_value(d: DimSize):
"""Turns a dimension size into a JAX array.

View File

@ -84,8 +84,9 @@ T = TypeVar("T")
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
def _clip_int_to_valid_range(val: DimSize, dtype) -> int:
def _clip_int_to_valid_range(val: DimSize, dtype, where: str) -> int:
info = np.iinfo(dtype)
val = core.concrete_dim_or_error(val, where)
return core.max_dim(info.min, core.min_dim(val, info.max))
def _validate_shapes(shapes: Sequence[Shape]):
@ -1350,7 +1351,8 @@ def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int) -> Array:
def _eye(dtype: DTypeLike, shape: Shape, offset: DimSize) -> Array:
"""Like numpy.eye, create a 2D array with ones on a diagonal."""
offset = _clip_int_to_valid_range(offset, np.int32)
offset = _clip_int_to_valid_range(offset, np.int32,
"argument `offset` of jax.numpy.eye")
dtype = dtypes.canonicalize_dtype(dtype)
bool_eye = eq(add(broadcasted_iota(np.int32, shape, 0), np.int32(offset)),
broadcasted_iota(np.int32, shape, 1))
@ -1372,7 +1374,8 @@ def _delta(dtype: DTypeLike, shape: Shape, axes: Sequence[int]) -> Array:
def _tri(dtype: DTypeLike, shape: Shape, offset: DimSize) -> Array:
"""Like numpy.tri, create a 2D array with ones below a diagonal."""
offset = _clip_int_to_valid_range(offset, np.int32)
offset = _clip_int_to_valid_range(offset, np.int32,
"argument `offset` of jax.numpy.tri")
dtype = dtypes.canonicalize_dtype(dtype)
bool_tri = ge(add(broadcasted_iota(np.int32, shape, 0),
asarray(core.dimension_as_value(offset)).astype(np.int32)),

View File

@ -4141,13 +4141,14 @@ def _eye(N: DimSize, M: DimSize | None = None,
dtype: DTypeLike | None = None) -> Array:
dtypes.check_user_dtype_supported(dtype, "eye")
if isinstance(k, int):
k = lax_internal._clip_int_to_valid_range(k, np.int32)
k = lax_internal._clip_int_to_valid_range(k, np.int32,
"`argument `k` of jax.numpy.eye")
util.check_arraylike("eye", k)
offset = asarray(k)
if not (offset.shape == () and dtypes.issubdtype(offset.dtype, np.integer)):
raise ValueError(f"k must be a scalar integer; got {k}")
N_int = core.canonicalize_dim(N, "'N' argument of jnp.eye()")
M_int = N_int if M is None else core.canonicalize_dim(M, "'M' argument of jnp.eye()")
N_int = core.canonicalize_dim(N, "argument of 'N' jnp.eye()")
M_int = N_int if M is None else core.canonicalize_dim(M, "argument 'M' of jnp.eye()")
if N_int < 0 or M_int < 0:
raise ValueError(f"negative dimensions are not allowed, got {N} and {M}")
i = lax.broadcasted_iota(offset.dtype, (N_int, M_int), 0)

View File

@ -2273,6 +2273,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
def test_tri_bug_22751(self):
with self.assertRaisesRegex(core.ConcretizationTypeError, "jax.numpy.tri"):
jax.jit(jnp.tri)(3, M=3, k=0)
@jtu.sample_product(
dtype=default_dtypes,
shape=[shape for shape in all_shapes if len(shape) >= 2],

View File

@ -567,6 +567,27 @@ class DimExprTest(jtu.JaxTestCase):
self.sampled_assertion(core.min_dim(a, 5), core.min_dim, a, 5)
self.sampled_assertion(core.min_dim(5, a), core.min_dim, 5, a)
def test_min_max_type_check(self):
a, = shape_poly.symbolic_shape("a")
for i, f in enumerate([lambda x: core.max_dim(x, a),
lambda x: core.max_dim(a, x),
lambda x: core.min_dim(x, a),
lambda x: core.min_dim(a, x)]):
with self.subTest(f"jit_{i}"):
with self.assertRaisesRegex(core.ConcretizationTypeError, ""):
jax.jit(f)(1)
arr = jnp.array([1], dtype=np.int32)
for i, f in enumerate([lambda: core.max_dim(arr, a),
lambda: core.max_dim(a, arr),
lambda: core.min_dim(arr, a),
lambda: core.min_dim(a, arr)]):
with self.subTest(f"array_{i}"):
with self.assertRaisesRegex(
TypeError,
"Only integer scalar arrays can be converted to a scalar index"):
f()
def test_clamp_dim(self):
a, b = shape_poly.symbolic_shape("a, b")
# Clamping b <= a <= b + 10