mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Correct norm in ann.py doc.
PiperOrigin-RevId: 487814084
This commit is contained in:
parent
ce85106578
commit
10e6fe8cde
@ -188,10 +188,10 @@ def approx_min_k(operand: Array,
|
||||
>>>
|
||||
>>> qy = jax.numpy.array(np.random.rand(50, 64))
|
||||
>>> db = jax.numpy.array(np.random.rand(1024, 64))
|
||||
>>> half_db_norms = jax.numpy.linalg.norm(db, axis=1) / 2
|
||||
>>> dists, neighbors = l2_ann(qy, db, half_db_norms, k=10)
|
||||
>>> half_db_norm_sq = jax.numpy.linalg.norm(db, axis=1)**2 / 2
|
||||
>>> dists, neighbors = l2_ann(qy, db, half_db_norm_sq, k=10)
|
||||
|
||||
In the example above, we compute ``db_norms/2 - dot(qy, db^T)`` instead of
|
||||
In the example above, we compute ``db^2/2 - dot(qy, db^T)`` instead of
|
||||
``qy^2 - 2 dot(qy, db^T) + db^2`` for performance reason. The former uses less
|
||||
arithmetics and produces the same set of neighbors.
|
||||
"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user