diff --git a/CHANGELOG.md b/CHANGELOG.md index 9897b7c8d..d1386e0a1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. * [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.7...main). * Changes + * {func}`jax.scipy.cluster.vq.vq` has been added. * `jax.experimental.maps.mesh` has been deleted. Please use `jax.experimental.maps.Mesh`. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh for more information. diff --git a/jax/_src/scipy/cluster/vq.py b/jax/_src/scipy/cluster/vq.py new file mode 100644 index 000000000..94de1e667 --- /dev/null +++ b/jax/_src/scipy/cluster/vq.py @@ -0,0 +1,47 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import operator + +import scipy.cluster.vq +import textwrap + +from jax import vmap +from jax._src.numpy.util import _wraps, _check_arraylike, _promote_dtypes_inexact +from jax._src.numpy.lax_numpy import argmin +from jax._src.numpy.linalg import norm + + +_no_chkfinite_doc = textwrap.dedent(""" +Does not support the Scipy argument ``check_finite=True``, +because compiled JAX code cannot perform checks of array values at runtime +""") + + +@_wraps(scipy.cluster.vq.vq, lax_description=_no_chkfinite_doc, skip_params=('check_finite',)) +def vq(obs, code_book, check_finite=True): + _check_arraylike("scipy.cluster.vq.vq", obs, code_book) + if obs.ndim != code_book.ndim: + raise ValueError("Observation and code_book should have the same rank") + obs, code_book = _promote_dtypes_inexact(obs, code_book) + if obs.ndim == 1: + obs, code_book = obs[..., None], code_book[..., None] + if obs.ndim != 2: + raise ValueError("ndim different than 1 or 2 are not supported") + + # explicitly rank promotion + dist = vmap(lambda ob: norm(ob[None] - code_book, axis=-1))(obs) + code = argmin(dist, axis=-1) + dist_min = vmap(operator.getitem)(dist, code) + return code, dist_min diff --git a/jax/scipy/__init__.py b/jax/scipy/__init__.py index 149311db0..0f504cf64 100644 --- a/jax/scipy/__init__.py +++ b/jax/scipy/__init__.py @@ -20,3 +20,4 @@ from jax.scipy import sparse as sparse from jax.scipy import special as special from jax.scipy import stats as stats from jax.scipy import fft as fft +from jax.scipy import cluster as cluster diff --git a/jax/scipy/cluster/__init__.py b/jax/scipy/cluster/__init__.py new file mode 100644 index 000000000..c0c3ced15 --- /dev/null +++ b/jax/scipy/cluster/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from jax.scipy.cluster import vq as vq diff --git a/jax/scipy/cluster/vq.py b/jax/scipy/cluster/vq.py new file mode 100644 index 000000000..0b836b09c --- /dev/null +++ b/jax/scipy/cluster/vq.py @@ -0,0 +1,15 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from jax._src.scipy.cluster.vq import vq as vq diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index 9b5ad7eb3..f56a2cc0b 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -24,6 +24,7 @@ from absl.testing import parameterized import numpy as np import scipy.special as osp_special +import scipy.cluster as osp_cluster import jax from jax import numpy as jnp @@ -31,6 +32,7 @@ from jax import lax from jax import scipy as jsp from jax._src import test_util as jtu from jax.scipy import special as lsp_special +from jax.scipy import cluster as lsp_cluster import jax._src.scipy.eigh from jax.config import config @@ -610,6 +612,19 @@ class LaxBackedScipyTests(jtu.JaxTestCase): v_eye = jnp.eye(v_unitary_delta.shape[0], dtype=dtype) self.assertAllClose(v_unitary_delta, v_eye, atol=eps) + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": f"_{jtu.format_shape_dtype_string((n_obs, n_codes, *n_feats), dtype)}", + "n_obs": n_obs, "n_codes": n_codes, "n_feats": n_feats, "dtype": dtype} + for n_obs in [1, 3, 5] + for n_codes in [1, 2, 4] + for n_feats in [()] + [(i,) for i in range(1, 3)] + for dtype in float_dtypes + int_dtypes)) # scipy doesn't support complex + def test_vq(self, n_obs, n_codes, n_feats, dtype): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng((n_obs, *n_feats), dtype), rng((n_codes, *n_feats), dtype)] + self._CheckAgainstNumpy(osp_cluster.vq.vq, lsp_cluster.vq.vq, args_maker, check_dtypes=False) + self._CompileAndCheck(lsp_cluster.vq.vq, args_maker) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())