mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
implement scipy.cluster.vq.vq
also add no check_finite and overwrite_* docstring for some scipy.linalg functions
This commit is contained in:
parent
0ed29b63f0
commit
b485b8e5ce
@ -13,6 +13,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
* [GitHub
|
||||
commits](https://github.com/google/jax/compare/jax-v0.3.7...main).
|
||||
* Changes
|
||||
* {func}`jax.scipy.cluster.vq.vq` has been added.
|
||||
* `jax.experimental.maps.mesh` has been deleted.
|
||||
Please use `jax.experimental.maps.Mesh`. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh
|
||||
for more information.
|
||||
|
47
jax/_src/scipy/cluster/vq.py
Normal file
47
jax/_src/scipy/cluster/vq.py
Normal file
@ -0,0 +1,47 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import operator
|
||||
|
||||
import scipy.cluster.vq
|
||||
import textwrap
|
||||
|
||||
from jax import vmap
|
||||
from jax._src.numpy.util import _wraps, _check_arraylike, _promote_dtypes_inexact
|
||||
from jax._src.numpy.lax_numpy import argmin
|
||||
from jax._src.numpy.linalg import norm
|
||||
|
||||
|
||||
_no_chkfinite_doc = textwrap.dedent("""
|
||||
Does not support the Scipy argument ``check_finite=True``,
|
||||
because compiled JAX code cannot perform checks of array values at runtime
|
||||
""")
|
||||
|
||||
|
||||
@_wraps(scipy.cluster.vq.vq, lax_description=_no_chkfinite_doc, skip_params=('check_finite',))
|
||||
def vq(obs, code_book, check_finite=True):
|
||||
_check_arraylike("scipy.cluster.vq.vq", obs, code_book)
|
||||
if obs.ndim != code_book.ndim:
|
||||
raise ValueError("Observation and code_book should have the same rank")
|
||||
obs, code_book = _promote_dtypes_inexact(obs, code_book)
|
||||
if obs.ndim == 1:
|
||||
obs, code_book = obs[..., None], code_book[..., None]
|
||||
if obs.ndim != 2:
|
||||
raise ValueError("ndim different than 1 or 2 are not supported")
|
||||
|
||||
# explicitly rank promotion
|
||||
dist = vmap(lambda ob: norm(ob[None] - code_book, axis=-1))(obs)
|
||||
code = argmin(dist, axis=-1)
|
||||
dist_min = vmap(operator.getitem)(dist, code)
|
||||
return code, dist_min
|
@ -20,3 +20,4 @@ from jax.scipy import sparse as sparse
|
||||
from jax.scipy import special as special
|
||||
from jax.scipy import stats as stats
|
||||
from jax.scipy import fft as fft
|
||||
from jax.scipy import cluster as cluster
|
||||
|
15
jax/scipy/cluster/__init__.py
Normal file
15
jax/scipy/cluster/__init__.py
Normal file
@ -0,0 +1,15 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from jax.scipy.cluster import vq as vq
|
15
jax/scipy/cluster/vq.py
Normal file
15
jax/scipy/cluster/vq.py
Normal file
@ -0,0 +1,15 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from jax._src.scipy.cluster.vq import vq as vq
|
@ -24,6 +24,7 @@ from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
import scipy.special as osp_special
|
||||
import scipy.cluster as osp_cluster
|
||||
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
@ -31,6 +32,7 @@ from jax import lax
|
||||
from jax import scipy as jsp
|
||||
from jax._src import test_util as jtu
|
||||
from jax.scipy import special as lsp_special
|
||||
from jax.scipy import cluster as lsp_cluster
|
||||
import jax._src.scipy.eigh
|
||||
|
||||
from jax.config import config
|
||||
@ -610,6 +612,19 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
v_eye = jnp.eye(v_unitary_delta.shape[0], dtype=dtype)
|
||||
self.assertAllClose(v_unitary_delta, v_eye, atol=eps)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": f"_{jtu.format_shape_dtype_string((n_obs, n_codes, *n_feats), dtype)}",
|
||||
"n_obs": n_obs, "n_codes": n_codes, "n_feats": n_feats, "dtype": dtype}
|
||||
for n_obs in [1, 3, 5]
|
||||
for n_codes in [1, 2, 4]
|
||||
for n_feats in [()] + [(i,) for i in range(1, 3)]
|
||||
for dtype in float_dtypes + int_dtypes)) # scipy doesn't support complex
|
||||
def test_vq(self, n_obs, n_codes, n_feats, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng((n_obs, *n_feats), dtype), rng((n_codes, *n_feats), dtype)]
|
||||
self._CheckAgainstNumpy(osp_cluster.vq.vq, lsp_cluster.vq.vq, args_maker, check_dtypes=False)
|
||||
self._CompileAndCheck(lsp_cluster.vq.vq, args_maker)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user