mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Implement np.gcd and np.lcm.
Taking the loop primitives out for a spin!
This commit is contained in:
parent
51f0291576
commit
7fc4e0237b
@ -87,6 +87,7 @@ jax.numpy package
|
||||
fmod
|
||||
full
|
||||
full_like
|
||||
gcd
|
||||
geomspace
|
||||
greater
|
||||
greater_equal
|
||||
@ -111,6 +112,7 @@ jax.numpy package
|
||||
issubsctype
|
||||
kaiser
|
||||
kron
|
||||
lcm
|
||||
left_shift
|
||||
less
|
||||
less_equal
|
||||
|
@ -2058,6 +2058,32 @@ hanning = onp.hanning
|
||||
kaiser = onp.kaiser # TODO: lower via lax to allow non-constant beta.
|
||||
|
||||
|
||||
@_wraps(onp.gcd)
|
||||
def gcd(x1, x2):
|
||||
if (not issubdtype(lax._dtype(x1), integer) or
|
||||
not issubdtype(lax._dtype(x2), integer)):
|
||||
raise ValueError("Arguments to gcd must be integers.")
|
||||
def cond_fn(xs):
|
||||
x1, x2 = xs
|
||||
return any(x2 != 0)
|
||||
def body_fn(xs):
|
||||
x1, x2 = xs
|
||||
x1, x2 = (where(x2 != 0, x2, x1),
|
||||
where(x2 != 0, lax.rem(x1, x2), lax._const(x2, 0)))
|
||||
return (where(x1 < x2, x2, x1), where(x1 < x2, x1, x2))
|
||||
x1, x2 = _promote_dtypes(lax.abs(x1), lax.abs(x2))
|
||||
x1, x2 = broadcast_arrays(x1, x2)
|
||||
gcd, _ = lax.while_loop(cond_fn, body_fn, (x1, x2))
|
||||
return gcd
|
||||
|
||||
|
||||
@_wraps(onp.lcm)
|
||||
def lcm(x1, x2):
|
||||
d = gcd(x1, x2)
|
||||
return where(d == 0, lax._const(d, 0),
|
||||
lax.div(lax.abs(multiply(x1, x2)), d))
|
||||
|
||||
|
||||
### track unimplemented functions
|
||||
|
||||
def _not_implemented(fun):
|
||||
|
@ -126,6 +126,7 @@ JAX_COMPOUND_OP_RECORDS = [
|
||||
test_name="expm1_large"),
|
||||
op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_small_positive(), []),
|
||||
op_record("floor_divide", 2, number_dtypes, all_shapes, jtu.rand_nonzero(), ["rev"]),
|
||||
op_record("gcd", 2, int_dtypes, all_shapes, jtu.rand_default(), []),
|
||||
op_record("heaviside", 2, default_dtypes, all_shapes, jtu.rand_default(), []),
|
||||
op_record("hypot", 2, default_dtypes, all_shapes, jtu.rand_default(), []),
|
||||
op_record("kron", 2, number_dtypes, nonempty_shapes, jtu.rand_default(), []),
|
||||
@ -134,6 +135,7 @@ JAX_COMPOUND_OP_RECORDS = [
|
||||
op_record("isclose", 2, all_dtypes, all_shapes, jtu.rand_small_positive(), []),
|
||||
op_record("iscomplex", 1, number_dtypes, all_shapes, jtu.rand_some_inf(), []),
|
||||
op_record("isreal", 1, number_dtypes, all_shapes, jtu.rand_some_inf(), []),
|
||||
op_record("lcm", 2, int_dtypes, all_shapes, jtu.rand_default(), []),
|
||||
op_record("log2", 1, number_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
|
||||
op_record("log10", 1, number_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
|
||||
op_record("log1p", 1, number_dtypes, all_shapes, jtu.rand_positive(), [],
|
||||
|
Loading…
x
Reference in New Issue
Block a user