Correct norm in ann.py doc.

PiperOrigin-RevId: 487814084
This commit is contained in:
Felix Chern 2022-11-11 07:08:27 -08:00 committed by jax authors
parent ce85106578
commit 10e6fe8cde

View File

@ -188,10 +188,10 @@ def approx_min_k(operand: Array,
>>>
>>> qy = jax.numpy.array(np.random.rand(50, 64))
>>> db = jax.numpy.array(np.random.rand(1024, 64))
>>> half_db_norms = jax.numpy.linalg.norm(db, axis=1) / 2
>>> dists, neighbors = l2_ann(qy, db, half_db_norms, k=10)
>>> half_db_norm_sq = jax.numpy.linalg.norm(db, axis=1)**2 / 2
>>> dists, neighbors = l2_ann(qy, db, half_db_norm_sq, k=10)
In the example above, we compute ``db_norms/2 - dot(qy, db^T)`` instead of
In the example above, we compute ``db^2/2 - dot(qy, db^T)`` instead of
``qy^2 - 2 dot(qy, db^T) + db^2`` for performance reason. The former uses less
arithmetics and produces the same set of neighbors.
"""