Merge pull request #13189 from Ishticode:lcm_update

PiperOrigin-RevId: 488383042
This commit is contained in:
jax authors 2022-11-14 09:10:39 -08:00
commit b086e73d36
2 changed files with 3 additions and 1 deletions

View File

@ -4367,11 +4367,12 @@ def gcd(x1: ArrayLike, x2: ArrayLike) -> Array:
def lcm(x1: ArrayLike, x2: ArrayLike) -> Array:
_check_arraylike("lcm", x1, x2)
x1, x2 = _promote_dtypes(x1, x2)
x1, x2 = abs(x1), abs(x2)
if not issubdtype(_dtype(x1), integer):
raise ValueError("Arguments to jax.numpy.lcm must be integers.")
d = gcd(x1, x2)
return where(d == 0, _lax_const(d, 0),
abs(multiply(x1, floor_divide(x2, d))))
multiply(x1, floor_divide(x2, d)))
@_wraps(np.extract)

View File

@ -282,6 +282,7 @@ JAX_COMPOUND_OP_RECORDS = [
all_shapes, jtu.rand_small_positive, []),
op_record("gcd", 2, int_dtypes_no_uint64, all_shapes, jtu.rand_default, []),
op_record("lcm", 2, int_dtypes_no_uint64, all_shapes, jtu.rand_default, []),
op_record("lcm", 2, [np.int8], all_shapes, jtu.rand_not_small, [])
]
JAX_BITWISE_OP_RECORDS = [