Implement np.gcd and np.lcm.

Taking the loop primitives out for a spin!
This commit is contained in:
Peter Hawkins 2019-02-19 15:57:22 -05:00
parent 51f0291576
commit 7fc4e0237b
3 changed files with 30 additions and 0 deletions

View File

@ -87,6 +87,7 @@ jax.numpy package
fmod
full
full_like
gcd
geomspace
greater
greater_equal
@ -111,6 +112,7 @@ jax.numpy package
issubsctype
kaiser
kron
lcm
left_shift
less
less_equal

View File

@ -2058,6 +2058,32 @@ hanning = onp.hanning
kaiser = onp.kaiser # TODO: lower via lax to allow non-constant beta.
@_wraps(onp.gcd)
def gcd(x1, x2):
if (not issubdtype(lax._dtype(x1), integer) or
not issubdtype(lax._dtype(x2), integer)):
raise ValueError("Arguments to gcd must be integers.")
def cond_fn(xs):
x1, x2 = xs
return any(x2 != 0)
def body_fn(xs):
x1, x2 = xs
x1, x2 = (where(x2 != 0, x2, x1),
where(x2 != 0, lax.rem(x1, x2), lax._const(x2, 0)))
return (where(x1 < x2, x2, x1), where(x1 < x2, x1, x2))
x1, x2 = _promote_dtypes(lax.abs(x1), lax.abs(x2))
x1, x2 = broadcast_arrays(x1, x2)
gcd, _ = lax.while_loop(cond_fn, body_fn, (x1, x2))
return gcd
@_wraps(onp.lcm)
def lcm(x1, x2):
d = gcd(x1, x2)
return where(d == 0, lax._const(d, 0),
lax.div(lax.abs(multiply(x1, x2)), d))
### track unimplemented functions
def _not_implemented(fun):

View File

@ -126,6 +126,7 @@ JAX_COMPOUND_OP_RECORDS = [
test_name="expm1_large"),
op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_small_positive(), []),
op_record("floor_divide", 2, number_dtypes, all_shapes, jtu.rand_nonzero(), ["rev"]),
op_record("gcd", 2, int_dtypes, all_shapes, jtu.rand_default(), []),
op_record("heaviside", 2, default_dtypes, all_shapes, jtu.rand_default(), []),
op_record("hypot", 2, default_dtypes, all_shapes, jtu.rand_default(), []),
op_record("kron", 2, number_dtypes, nonempty_shapes, jtu.rand_default(), []),
@ -134,6 +135,7 @@ JAX_COMPOUND_OP_RECORDS = [
op_record("isclose", 2, all_dtypes, all_shapes, jtu.rand_small_positive(), []),
op_record("iscomplex", 1, number_dtypes, all_shapes, jtu.rand_some_inf(), []),
op_record("isreal", 1, number_dtypes, all_shapes, jtu.rand_some_inf(), []),
op_record("lcm", 2, int_dtypes, all_shapes, jtu.rand_default(), []),
op_record("log2", 1, number_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
op_record("log10", 1, number_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
op_record("log1p", 1, number_dtypes, all_shapes, jtu.rand_positive(), [],