mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #13189 from Ishticode:lcm_update
PiperOrigin-RevId: 488383042
This commit is contained in:
commit
b086e73d36
@ -4367,11 +4367,12 @@ def gcd(x1: ArrayLike, x2: ArrayLike) -> Array:
|
||||
def lcm(x1: ArrayLike, x2: ArrayLike) -> Array:
|
||||
_check_arraylike("lcm", x1, x2)
|
||||
x1, x2 = _promote_dtypes(x1, x2)
|
||||
x1, x2 = abs(x1), abs(x2)
|
||||
if not issubdtype(_dtype(x1), integer):
|
||||
raise ValueError("Arguments to jax.numpy.lcm must be integers.")
|
||||
d = gcd(x1, x2)
|
||||
return where(d == 0, _lax_const(d, 0),
|
||||
abs(multiply(x1, floor_divide(x2, d))))
|
||||
multiply(x1, floor_divide(x2, d)))
|
||||
|
||||
|
||||
@_wraps(np.extract)
|
||||
|
@ -282,6 +282,7 @@ JAX_COMPOUND_OP_RECORDS = [
|
||||
all_shapes, jtu.rand_small_positive, []),
|
||||
op_record("gcd", 2, int_dtypes_no_uint64, all_shapes, jtu.rand_default, []),
|
||||
op_record("lcm", 2, int_dtypes_no_uint64, all_shapes, jtu.rand_default, []),
|
||||
op_record("lcm", 2, [np.int8], all_shapes, jtu.rand_not_small, [])
|
||||
]
|
||||
|
||||
JAX_BITWISE_OP_RECORDS = [
|
||||
|
Loading…
x
Reference in New Issue
Block a user