mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #22779 from gnecula:tril
PiperOrigin-RevId: 658271885
This commit is contained in:
commit
6c083d78e6
@ -2191,20 +2191,24 @@ def min_dim(d1: DimSize, d2: DimSize) -> DimSize:
|
||||
d1_is_constant = is_constant_dim(d1)
|
||||
if d1_is_constant and is_constant_dim(d2):
|
||||
return min(d1, d2)
|
||||
d1 = concrete_dim_or_error(d1, "argument `d1` of `core.min_dim`")
|
||||
d2 = concrete_dim_or_error(d2, "argument `d2` of `core.min_dim`")
|
||||
if d1_is_constant:
|
||||
return d2.rmin(d1) # type: ignore[union-attr]
|
||||
return d2.rmin(d1)
|
||||
else:
|
||||
return d1.min(d2) # type: ignore[union-attr]
|
||||
return d1.min(d2)
|
||||
|
||||
def max_dim(d1: DimSize, d2: DimSize) -> DimSize:
|
||||
"""Like max(d1, d2) but for both constant and symbolic dimensions."""
|
||||
d1_is_constant = is_constant_dim(d1)
|
||||
if d1_is_constant and is_constant_dim(d2):
|
||||
return max(d1, d2)
|
||||
d1 = concrete_dim_or_error(d1, "argument `d1` of `core.max_dim`")
|
||||
d2 = concrete_dim_or_error(d2, "argument `d2` of `core.max_dim`")
|
||||
if d1_is_constant:
|
||||
return d2.rmax(d1) # type: ignore[union-attr]
|
||||
return d2.rmax(d1)
|
||||
else:
|
||||
return d1.max(d2) # type: ignore[union-attr]
|
||||
return d1.max(d2)
|
||||
|
||||
def dimension_as_value(d: DimSize):
|
||||
"""Turns a dimension size into a JAX array.
|
||||
|
@ -84,8 +84,9 @@ T = TypeVar("T")
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
def _clip_int_to_valid_range(val: DimSize, dtype) -> int:
|
||||
def _clip_int_to_valid_range(val: DimSize, dtype, where: str) -> int:
|
||||
info = np.iinfo(dtype)
|
||||
val = core.concrete_dim_or_error(val, where)
|
||||
return core.max_dim(info.min, core.min_dim(val, info.max))
|
||||
|
||||
def _validate_shapes(shapes: Sequence[Shape]):
|
||||
@ -1350,7 +1351,8 @@ def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int) -> Array:
|
||||
|
||||
def _eye(dtype: DTypeLike, shape: Shape, offset: DimSize) -> Array:
|
||||
"""Like numpy.eye, create a 2D array with ones on a diagonal."""
|
||||
offset = _clip_int_to_valid_range(offset, np.int32)
|
||||
offset = _clip_int_to_valid_range(offset, np.int32,
|
||||
"argument `offset` of jax.numpy.eye")
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
bool_eye = eq(add(broadcasted_iota(np.int32, shape, 0), np.int32(offset)),
|
||||
broadcasted_iota(np.int32, shape, 1))
|
||||
@ -1372,7 +1374,8 @@ def _delta(dtype: DTypeLike, shape: Shape, axes: Sequence[int]) -> Array:
|
||||
|
||||
def _tri(dtype: DTypeLike, shape: Shape, offset: DimSize) -> Array:
|
||||
"""Like numpy.tri, create a 2D array with ones below a diagonal."""
|
||||
offset = _clip_int_to_valid_range(offset, np.int32)
|
||||
offset = _clip_int_to_valid_range(offset, np.int32,
|
||||
"argument `offset` of jax.numpy.tri")
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
bool_tri = ge(add(broadcasted_iota(np.int32, shape, 0),
|
||||
asarray(core.dimension_as_value(offset)).astype(np.int32)),
|
||||
|
@ -4141,13 +4141,14 @@ def _eye(N: DimSize, M: DimSize | None = None,
|
||||
dtype: DTypeLike | None = None) -> Array:
|
||||
dtypes.check_user_dtype_supported(dtype, "eye")
|
||||
if isinstance(k, int):
|
||||
k = lax_internal._clip_int_to_valid_range(k, np.int32)
|
||||
k = lax_internal._clip_int_to_valid_range(k, np.int32,
|
||||
"`argument `k` of jax.numpy.eye")
|
||||
util.check_arraylike("eye", k)
|
||||
offset = asarray(k)
|
||||
if not (offset.shape == () and dtypes.issubdtype(offset.dtype, np.integer)):
|
||||
raise ValueError(f"k must be a scalar integer; got {k}")
|
||||
N_int = core.canonicalize_dim(N, "'N' argument of jnp.eye()")
|
||||
M_int = N_int if M is None else core.canonicalize_dim(M, "'M' argument of jnp.eye()")
|
||||
N_int = core.canonicalize_dim(N, "argument of 'N' jnp.eye()")
|
||||
M_int = N_int if M is None else core.canonicalize_dim(M, "argument 'M' of jnp.eye()")
|
||||
if N_int < 0 or M_int < 0:
|
||||
raise ValueError(f"negative dimensions are not allowed, got {N} and {M}")
|
||||
i = lax.broadcasted_iota(offset.dtype, (N_int, M_int), 0)
|
||||
|
@ -2273,6 +2273,10 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
def test_tri_bug_22751(self):
|
||||
with self.assertRaisesRegex(core.ConcretizationTypeError, "jax.numpy.tri"):
|
||||
jax.jit(jnp.tri)(3, M=3, k=0)
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype=default_dtypes,
|
||||
shape=[shape for shape in all_shapes if len(shape) >= 2],
|
||||
|
@ -567,6 +567,27 @@ class DimExprTest(jtu.JaxTestCase):
|
||||
self.sampled_assertion(core.min_dim(a, 5), core.min_dim, a, 5)
|
||||
self.sampled_assertion(core.min_dim(5, a), core.min_dim, 5, a)
|
||||
|
||||
def test_min_max_type_check(self):
|
||||
a, = shape_poly.symbolic_shape("a")
|
||||
for i, f in enumerate([lambda x: core.max_dim(x, a),
|
||||
lambda x: core.max_dim(a, x),
|
||||
lambda x: core.min_dim(x, a),
|
||||
lambda x: core.min_dim(a, x)]):
|
||||
with self.subTest(f"jit_{i}"):
|
||||
with self.assertRaisesRegex(core.ConcretizationTypeError, ""):
|
||||
jax.jit(f)(1)
|
||||
|
||||
arr = jnp.array([1], dtype=np.int32)
|
||||
for i, f in enumerate([lambda: core.max_dim(arr, a),
|
||||
lambda: core.max_dim(a, arr),
|
||||
lambda: core.min_dim(arr, a),
|
||||
lambda: core.min_dim(a, arr)]):
|
||||
with self.subTest(f"array_{i}"):
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"Only integer scalar arrays can be converted to a scalar index"):
|
||||
f()
|
||||
|
||||
def test_clamp_dim(self):
|
||||
a, b = shape_poly.symbolic_shape("a, b")
|
||||
# Clamping b <= a <= b + 10
|
||||
|
Loading…
x
Reference in New Issue
Block a user